Example #1
def _task_fn(index, driver_addresses, settings):
    task = task_service.SparkTaskService(index, settings.key)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        # Tasks ping each other in a circular fashion to determine interfaces reachable within
        # the cluster.
        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver_client.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = \
            task_service.SparkTaskClient(next_task_index, next_task_addresses,
                                         settings.key, settings.verbose,
                                         match_intf=True)
        driver_client.register_task_to_task_addresses(next_task_index, next_task_client.addresses())
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        if task_indices_on_this_host[0] == index:
            # The task with the first index will execute orted, which will run mpirun_exec_fn for all tasks.
            task.wait_for_command_start(settings.timeout)
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        task.shutdown()
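
For context, the sketch below shows how a task function like `_task_fn` is typically fanned out to Spark executors: one Spark partition per Horovod process, dispatched via `mapPartitionsWithIndex`. This is an illustration only; the wrapper names (`_make_mapper`, `_launch_tasks`) and the live `SparkContext` argument `sc` are assumptions, not part of the example above.

# Hypothetical dispatch sketch: run one _task_fn per Spark partition,
# with the partition index doubling as the Horovod task index.
def _make_mapper(driver_addresses, settings):
    def _mapper(index, _iterator):
        yield _task_fn(index, driver_addresses, settings)
    return _mapper

def _launch_tasks(sc, driver_addresses, settings):
    return (sc.range(0, settings.num_proc, numSlices=settings.num_proc)
              .mapPartitionsWithIndex(_make_mapper(driver_addresses, settings))
              .collect())
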
Example #2
File: __init__.py  Project: zpcalan/horovod
def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(), ))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    local_rank = int(os.environ[local_rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses,
                                                     key,
                                                     verbose=settings.verbose)

    # tell driver about local rank and rank
    # in elastic mode the driver already knows this mapping
    # for simplicity we keep code paths the same for elastic and static mode
    host_hash = os.environ['HOROVOD_HOSTNAME']
    task_index = driver_client.set_local_rank_to_rank(host_hash, local_rank,
                                                      rank)

    # gather available resources from task service
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index,
                                               task_addresses,
                                               key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)
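
`task_exec` above is meant to be called from a small exec-fn entry point launched on each worker. Below is a hedged sketch of such an entry point; the base64-encoded argv layout mirrors the other examples, the `codec` import path may differ across Horovod versions, and the Open MPI rank variable names are assumptions for illustration.

# Hypothetical exec-fn entry point (illustrative only).
import sys
from horovod.runner.common.util import codec  # path may differ by Horovod version

def main():
    # argv[1]: base64-encoded driver addresses, argv[2]: base64-encoded settings
    driver_addresses = codec.loads_base64(sys.argv[1])
    settings = codec.loads_base64(sys.argv[2])
    # rank env var names assumed for an Open MPI launch
    task_exec(driver_addresses, settings,
              'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK')

if __name__ == '__main__':
    main()
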
Example #3
def rsh(driver_addresses, settings, host_hash, command, env, local_rank):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with the given local rank
    of that host hash, and invokes the command there.

    :param driver_addresses: driver's addresses
    :param settings: settings
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    """
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    key = codec.loads_base64(env[secret.HOROVOD_SECRET_KEY])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses,
                                               key, verbose=settings.verbose)
    task_client.run_command(command, env)
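
A hedged example of calling this `rsh` from driver-side code: launch a short command in the task that holds local rank 0 of a known host hash. The host hash string and command are placeholders, and `driver`, `settings`, `secret` and `codec` are assumed to be the same objects and modules used in the surrounding examples; the environment must carry the base64-encoded secret key so `rsh` can authenticate against the driver.

# Illustrative call only (placeholder host hash and command).
import os

env = dict(os.environ)
env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)
rsh(driver.addresses(), settings, 'some-host-hash', 'hostname', env, local_rank=0)
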
Example #4
def main(driver_addresses, settings, host_hash, command):
    """
    Method to run `orted` remotely given a host hash and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with the lowest task index
    of that host hash, and invokes the command there.
    All other tasks with the same host hash are expected to no-op (see `horovod.spark._task_fn`)
    and wait for the first task to terminate.

    :param driver_addresses: driver's addresses
    :param settings: settings
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    """
    if ':' in host_hash:
        raise Exception(
            'Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    driver_client = driver_service.SparkDriverClient(driver_addresses,
                                                     key,
                                                     verbose=settings.verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    # Since tasks with the same host hash have shared memory, we will run only
    # one ORTED process on the first task.
    first_task_index = task_indices[0]
    task_addresses = driver_client.all_task_addresses(first_task_index)
    task_client = task_service.SparkTaskClient(first_task_index,
                                               task_addresses,
                                               key,
                                               verbose=settings.verbose)
    task_client.run_command(command, os.environ)
Example #5
File: runner.py  Project: yunzhezyz/horovod
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    # Settings deserialized on Spark workers do not contain the key, so it is passed here explicitly;
    # Spark RPC communicates the key and supports encryption.
    # For convenience, we put it back into settings.
    settings.key = key

    task = task_service.SparkTaskService(index, settings.key, settings.nics,
                                         settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(),
                                    host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())

        # With Gloo, all tasks wait for the command.
        # With MPI, the task with the first index executes orted, which will run mpirun_exec_fn for all tasks.
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        task.shutdown()
Example #6
def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with the given local rank
    of that host hash, and invokes the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    the command's exit code. If there is an exception while waiting for the result (e.g. connection reset),
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param verbose: verbosity level
    :param stdout: Task stdout is redirected to this stream.
    :param stderr: Task stderr is redirected to this stream.
    :param prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver if True
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    :return: exit code if background is False
    """
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

    if not background:
        events = events or []
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)

        try:
            exit_code = task_client.wait_for_command_exit_code()
            logging.debug('rsh exit code %s for host %s slot %s', exit_code, host_hash, local_rank)
            return exit_code
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()
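
For the foreground path, a caller typically passes one or more abort events and inspects the returned exit code. A minimal usage sketch under those assumptions (the host hash, command, `driver`, `key` and `env` are placeholders borrowed from the surrounding examples):

# Illustrative foreground call: wait for the remote command's exit code and
# allow it to be aborted whenever stop_event is set.
import threading

stop_event = threading.Event()
exit_code = rsh(driver.addresses(), key, 'some-host-hash', 'hostname', env,
                local_rank=0, verbose=2, background=False, events=[stop_event])
if exit_code != 0:
    raise RuntimeError('remote command failed with exit code %s' % exit_code)
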
Example #7
def notify_and_register(index):
    task_client = task_service.SparkTaskClient(
        index, driver.task_addresses_for_driver(index), settings.key,
        settings.verbose)
    task_client.notify_initial_registration_complete()
    next_task_index = (index + 1) % settings.num_proc
    next_task_addresses = driver.all_task_addresses(next_task_index)
    task_to_task_addresses = task_client.get_task_addresses_for_task(
        next_task_index, next_task_addresses)
    driver.register_task_to_task_addresses(next_task_index,
                                           task_to_task_addresses)
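
`notify_and_register` is a per-index helper; the launcher invokes it once for every task index after all tasks have registered. A hedged sketch of that fan-out, using a thread pool purely for illustration (the real project may drive it differently):

# Illustrative only: notify/register all tasks concurrently.
from concurrent.futures import ThreadPoolExecutor

def notify_and_register_all(num_proc):
    with ThreadPoolExecutor(max_workers=num_proc) as pool:
        # list() forces evaluation and re-raises any failure from the workers
        list(pool.map(notify_and_register, range(num_proc)))
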
Example #8
File: rsh.py  Project: zhaocq-nlp/horovod
def rsh(driver_addresses, key, settings, host_hash, command, env, local_rank,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with the given local rank
    of that host hash, and invokes the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    the command's exit code. If there is an exception while waiting for the result (e.g. connection reset),
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param settings: settings
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    """
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses,
                                               key, verbose=settings.verbose)
    task_client.run_command(command, env)

    if not background:
        events = events or []
        # share a single stop event across all abort watchers so they can all be
        # released once the command has finished
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)

        try:
            return task_client.wait_for_command_exit_code()
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()
Example #9
def main(driver_addresses, settings, host_hash, command):
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    # Since tasks with the same host hash have shared memory, we will run only
    # one ORTED process on the first task.
    first_task_index = task_indices[0]
    task_addresses = driver_client.all_task_addresses(first_task_index)
    task_client = task_service.SparkTaskClient(first_task_index, task_addresses,
                                               key, verbose=settings.verbose)
    task_client.run_command(command, os.environ)
Example #10
def main(driver_addresses):
    # Die if parent process terminates
    bg = threading.Thread(target=parent_process_monitor, args=(os.getppid(), ))
    bg.daemon = True
    bg.start()

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key)
    task_index = driver_client.task_index_by_rank(rank)
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key)
    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)
Example #11
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    # Settings deserialized on Spark workers do not contain the key, so it is passed here explicitly;
    # Spark RPC communicates the key and supports encryption.
    # For convenience, we put it back into settings.
    settings.key = key

    task = task_service.SparkTaskService(index, settings.key, settings.nics,
                                         settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(),
                                    host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())

        # With Gloo, all tasks wait for the command.
        # With MPI, the task with the first index executes orted, which will run mpirun_exec_fn for all tasks.
        minimum_lifetime_after_start = None
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            minimum_lifetime_after_start = timeout.Timeout(
                MINIMUM_COMMAND_LIFETIME_S, message='Just measuring runtime')
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        # The command has terminated; make sure this task service does not shut down too
        # quickly after the client started the command, as the client needs some time to
        # connect again and wait for the result (see horovod.spark.driver.rsh).
        if minimum_lifetime_after_start is not None:
            time.sleep(minimum_lifetime_after_start.remaining())

        return task.fn_result()
    finally:
        # Shutdown has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code; shutdown does block, since
        # network.BasicService._server._block_on_close is set to True.
        task.shutdown()
Example #12
def task_exec(driver_addresses, settings, rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(), ))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses,
                                                     key,
                                                     verbose=settings.verbose)
    task_index = driver_client.task_index_by_rank(rank)
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index,
                                               task_addresses,
                                               key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)
Example #13
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None, env=None,
        stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If that is also unset, defaults to 600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     key=secret.make_secret_key(),
                                     timeout=tmout)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs,
                                               settings.key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings)
    try:
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(index,
                                         driver.task_addresses_for_driver(index),
                                         settings.key, settings.verbose)
            for index in range(settings.num_proc)]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, settings.num_proc):
            common_intfs.intersection_update(driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception('Unable to find a set of common task-to-task communication interfaces: %s'
                            % [(index, driver.task_addresses_for_tasks(index)) for index in range(settings.num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses} {settings}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} {settings}'
                .format(num_proc=settings.num_proc,
                        hosts=','.join('%s:%d' % (host_hash, len(driver.task_host_hash_indices()[host_hash]))
                                       for host_hash in host_hashes),
                        common_intfs=','.join(common_intfs),
                        env=' '.join('-x %s' % key for key in env.keys() if env_util.is_exportable(key)),
                        python=sys.executable,
                        encoded_driver_addresses=codec.dumps_base64(driver.addresses()),
                        settings=codec.dumps_base64(settings)))
        if settings.verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.' % exit_code)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
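
For reference, a minimal way to call this `run` from a PySpark session. The training body is illustrative and assumes Horovod with TensorFlow is installed on the executors.

# Minimal usage sketch (assumes an active SparkContext and horovod[spark,tensorflow]).
import horovod.spark

def train():
    import horovod.tensorflow as hvd
    hvd.init()
    # ... build and train a model here ...
    return hvd.rank()

# Launches one Spark task per Horovod process and returns each rank's result.
results = horovod.spark.run(train, num_proc=4, verbose=2)
print(results)  # e.g. [0, 1, 2, 3]
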
Example #14
def run(fn,
        args=(),
        kwargs={},
        num_proc=None,
        start_timeout=None,
        use_mpi=None,
        use_gloo=None,
        extra_mpi_args=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If that is also unset, defaults to 600 seconds.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print(
                'Running %d processes (inferred from spark.default.parallelism)...'
                % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args,
                                               kwargs, settings.key,
                                               settings.nics)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, use_gloo)
    try:
        # wait for all tasks to register and notify them
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(
                index, driver.task_addresses_for_driver(index), settings.key,
                settings.verbose) for index in range(settings.num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join(
            '%s:%d' %
            (host_hash, len(driver.task_host_hash_indices()[host_hash]))
            for host_hash in host_hashes)

        # Determine the ranks-to-indices mapping
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
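
This newer signature adds `use_mpi`/`use_gloo` and `nics`. A hedged variation of the usage sketch from Example #13, forcing the Gloo controller and pinning the network interface (the interface name `eth0` is only an assumption):

# Reuses the `train` function from the Example #13 sketch; Gloo instead of MPI,
# restricted to a single NIC.
import horovod.spark

results = horovod.spark.run(train, num_proc=4, use_gloo=True,
                            nics={'eth0'}, start_timeout=300)
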
Example #15
File: runner.py  Project: lakersdf/horovod
def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    # Settings deserialized on Spark workers do not contain the key, so it is passed here explicitly;
    # Spark RPC communicates the key and supports encryption.
    # For convenience, we put it back into settings.
    settings.key = key

    # To simplify things, each task is an individual host in Elastic Horovod on Spark;
    # further, each attempt (instance) of a task is an individual host as well.
    # This hides the availability of shared memory among executors on the same Spark node.
    hosthash = host_hash(
        salt='{}-{}'.format(index, time.time()) if is_elastic else None)

    # provide host hash to mpirun_exec_fn.py via task service
    # gloo_exec_fn.py will get this env var set in request env as well
    os.environ['HOROVOD_HOSTNAME'] = hosthash

    task = task_service.SparkTaskService(
        index, settings.key, settings.nics,
        MINIMUM_COMMAND_LIFETIME_S if is_elastic or use_gloo else None,
        settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(), hosthash)

        if not is_elastic:
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(
                hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            local_rank_zero_index = None

        # In elastic mode, all tasks wait for the task shutdown signal from the driver.
        # With Gloo, all tasks wait for the command to start and terminate.
        # With MPI, the task with the first index executes orted, which will run mpirun_exec_fn for all tasks.
        if is_elastic:
            # either terminate on task shutdown or command termination
            shutdown_thread = in_thread(driver_client.wait_for_task_shutdown)

            while shutdown_thread.is_alive():
                # Once the command started we wait for its termination
                if task.check_for_command_start(
                        WAIT_FOR_COMMAND_START_DELAY_SECONDS):
                    task.wait_for_command_termination()
                    if task.command_exit_code() != 0:
                        raise Exception(
                            'Command failed, making Spark task fail to restart the task'
                        )
                    break

                # While no command started, we can shutdown any time
                shutdown_thread.join(WAIT_FOR_SHUTDOWN_DELAY_SECONDS)
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo or first task with MPI.
            task.wait_for_command_start(settings.start_timeout)
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        # We must not call into shutdown too quickly: task clients run a command and
        # want to wait on the result. We have told the task service not to return from
        # wait_for_command_termination too quickly, so by the time we shut down here,
        # clients have had enough time to connect to the service already.
        #
        # The shutdown has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code; shutdown does block, since
        # network.BasicService._server._block_on_close is set to True.
        task.shutdown()
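
The salted host hash at the top of this example is what makes every task attempt look like a distinct host in elastic mode, while static mode keeps one hash per physical host. A small illustration of that difference, reusing the `host_hash` helper exactly as it is called above (the index value is arbitrary):

# Illustration only: static vs. elastic host hashing.
import time

index = 3
static_hash = host_hash()  # identical for all tasks on the same physical host
elastic_hash = host_hash(salt='{}-{}'.format(index, time.time()))  # unique per task attempt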