예제 #1
0
def _task_fn(index, driver_addresses, settings):
    """Run the task service for Spark task ``index`` and return its result.

    The task registers with the driver, probes its ring neighbour to report
    which interfaces are reachable, and then either waits for its own command
    (first task on this host) or waits for the first task on this host to
    finish its command.

    Args:
        index: Index of this task.
        driver_addresses: Addresses used to construct the driver client.
        settings: Object providing ``key``, ``verbose``, ``timeout`` and
            ``num_proc``.

    Returns:
        Whatever ``fn_result()`` of the task service returns.
    """
    service = task_service.SparkTaskService(index, settings.key)
    try:
        driver = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver.register_task(index, service.addresses(), host_hash.host_hash())
        service.wait_for_initial_registration(settings.timeout)

        # Ring probe: every task contacts its successor (wrapping around at
        # num_proc) so the reachable interfaces within the cluster are found.
        neighbour_index = (index + 1) % settings.num_proc
        neighbour_addresses = driver.all_task_addresses(neighbour_index)
        # Interface matching is requested to weed out NAT'ed interfaces.
        neighbour = task_service.SparkTaskClient(
            neighbour_index, neighbour_addresses,
            settings.key, settings.verbose,
            match_intf=True)
        driver.register_task_to_task_addresses(neighbour_index, neighbour.addresses())

        local_indices = driver.task_host_hash_indices(host_hash.host_hash())
        leader_index = local_indices[0]
        if leader_index != index:
            # Not the first task on this host: just wait until the first
            # task's command has terminated.
            leader_addresses = driver.all_task_addresses(leader_index)
            leader = task_service.SparkTaskClient(
                leader_index, leader_addresses, settings.key, settings.verbose)
            leader.wait_for_command_termination()
        else:
            # The first task on this host executes orted, which runs
            # mpirun_exec_fn for all tasks.
            service.wait_for_command_start(settings.timeout)
            service.wait_for_command_termination()
        return service.fn_result()
    finally:
        service.shutdown()
예제 #2
0
파일: runner.py 프로젝트: yunzhezyz/horovod
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    """Run the task service for Spark task ``index`` and return its result.

    Args:
        index: Index of this task.
        driver_addresses: Addresses used to construct the driver client.
        key: Secret key; passed separately because the deserialized
            ``settings`` on Spark workers do not contain it.
        settings: Object providing ``nics``, ``verbose`` and ``timeout``;
            its ``key`` attribute is overwritten with ``key``.
        use_gloo: When truthy, every task waits for its own command.

    Returns:
        Whatever ``fn_result()`` of the task service returns.
    """
    # Spark RPC communicates the key and supports encryption; for convenience
    # the key is stored back on the settings object.
    settings.key = key

    service = task_service.SparkTaskService(
        index, settings.key, settings.nics, settings.verbose)
    try:
        driver = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver.register_task(index, service.addresses(), host_hash.host_hash())
        service.wait_for_initial_registration(settings.timeout)
        local_indices = driver.task_host_hash_indices(host_hash.host_hash())

        # Gloo: all tasks wait for the command.
        # MPI: only the first task on this host executes orted, which runs
        # mpirun_exec_fn for all tasks.
        runs_own_command = use_gloo or local_indices[0] == index
        if not runs_own_command:
            # Remaining MPI tasks wait for the first task on this host.
            leader_index = local_indices[0]
            leader_addresses = driver.all_task_addresses(leader_index)
            leader = task_service.SparkTaskClient(
                leader_index, leader_addresses, settings.key, settings.verbose)
            leader.wait_for_command_termination()
        else:
            service.wait_for_command_start(settings.timeout)
            service.wait_for_command_termination()
        return service.fn_result()
    finally:
        service.shutdown()
예제 #3
0
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    """Run the task service for Spark task ``index`` and return its result.

    After the command has terminated, a task that ran its own command stays
    alive for at least ``MINIMUM_COMMAND_LIFETIME_S`` seconds, measured from
    the command start, before returning.

    Args:
        index: Index of this task.
        driver_addresses: Addresses used to construct the driver client.
        key: Secret key; passed separately because the deserialized
            ``settings`` on Spark workers do not contain it.
        settings: Object providing ``nics``, ``verbose`` and ``timeout``;
            its ``key`` attribute is overwritten with ``key``.
        use_gloo: When truthy, every task waits for its own command.

    Returns:
        Whatever ``fn_result()`` of the task service returns.
    """
    # Spark RPC communicates the key and supports encryption; for convenience
    # the key is stored back on the settings object.
    settings.key = key

    service = task_service.SparkTaskService(
        index, settings.key, settings.nics, settings.verbose)
    try:
        driver = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver.register_task(index, service.addresses(), host_hash.host_hash())
        service.wait_for_initial_registration(settings.timeout)
        local_indices = driver.task_host_hash_indices(host_hash.host_hash())

        # Stays None for tasks that only wait on another task's command.
        lifetime_guard = None

        # Gloo: all tasks wait for the command.
        # MPI: only the first task on this host executes orted, which runs
        # mpirun_exec_fn for all tasks.
        runs_own_command = use_gloo or local_indices[0] == index
        if not runs_own_command:
            # Remaining MPI tasks wait for the first task on this host.
            leader_index = local_indices[0]
            leader_addresses = driver.all_task_addresses(leader_index)
            leader = task_service.SparkTaskClient(
                leader_index, leader_addresses, settings.key, settings.verbose)
            leader.wait_for_command_termination()
        else:
            service.wait_for_command_start(settings.timeout)
            # The timeout object is used purely as a stopwatch that starts
            # when the command starts.
            lifetime_guard = timeout.Timeout(
                MINIMUM_COMMAND_LIFETIME_S, message='Just measuring runtime')
            service.wait_for_command_termination()

        # The command has terminated. Do not let this task service shut down
        # too soon after the client started the command: the client needs
        # some time to connect again and wait for the result
        # (see horovod.spark.driver.rsh).
        if lifetime_guard is not None:
            seconds_left = lifetime_guard.remaining()
            time.sleep(seconds_left)

        return service.fn_result()
    finally:
        # Shutdown has to block on running requests
        # (wait_for_command_exit_code) so they can finish serving the exit
        # code; it does block with
        # network.BasicService._server._block_on_close = True.
        service.shutdown()
예제 #4
0
파일: runner.py 프로젝트: lakersdf/horovod
def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    """Run the task service for Spark task ``index`` and return its result.

    Three modes are handled:

    * elastic: the task ends on a driver shutdown signal or when its command
      terminates; a non-zero command exit code raises so Spark restarts it.
    * Gloo, or first MPI task on this host: wait for the command to start
      and terminate.
    * other MPI tasks: wait for the first task on this host to finish.

    Args:
        index: Index of this task.
        driver_addresses: Addresses used to construct the driver client.
        key: Secret key; passed separately because the deserialized
            ``settings`` on Spark workers do not contain it.
        settings: Object providing ``nics``, ``verbose`` and
            ``start_timeout``; its ``key`` attribute is overwritten.
        use_gloo: When truthy (and not elastic), every task waits for its
            own command.
        is_elastic: When truthy, run in elastic mode.

    Returns:
        Whatever ``fn_result()`` of the task service returns.

    Raises:
        Exception: In elastic mode, when the command exit code is non-zero.
    """
    # Spark RPC communicates the key and supports encryption; for convenience
    # the key is stored back on the settings object.
    settings.key = key

    # In Elastic Horovod on Spark each task, and each attempt (instance) of a
    # task, is treated as an individual host. Salting the host hash hides the
    # availability of shared memory among executors on the same Spark node.
    if is_elastic:
        salt = '{}-{}'.format(index, time.time())
    else:
        salt = None
    hosthash = host_hash(salt=salt)

    # Expose the host hash to mpirun_exec_fn.py via the task service;
    # gloo_exec_fn.py gets this env var set in the request env as well.
    os.environ['HOROVOD_HOSTNAME'] = hosthash

    if is_elastic or use_gloo:
        minimum_lifetime = MINIMUM_COMMAND_LIFETIME_S
    else:
        minimum_lifetime = None
    service = task_service.SparkTaskService(
        index, settings.key, settings.nics, minimum_lifetime, settings.verbose)
    try:
        driver = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver.register_task(index, service.addresses(), hosthash)

        if is_elastic:
            # Terminate either on task shutdown or on command termination.
            shutdown_waiter = in_thread(driver.wait_for_task_shutdown)

            while shutdown_waiter.is_alive():
                command_started = service.check_for_command_start(
                    WAIT_FOR_COMMAND_START_DELAY_SECONDS)
                if not command_started:
                    # While no command has started, shutdown may come at any
                    # time; give the shutdown thread a chance to finish.
                    shutdown_waiter.join(WAIT_FOR_SHUTDOWN_DELAY_SECONDS)
                    continue

                # Once the command has started, wait for its termination.
                service.wait_for_command_termination()
                if service.command_exit_code() != 0:
                    raise Exception(
                        'Command failed, making Spark task fail to restart the task'
                    )
                break
        else:
            service.wait_for_initial_registration(settings.start_timeout)
            local_indices = driver.task_host_hash_indices(hosthash)
            local_rank_zero_index = local_indices[0]

            if use_gloo or index == local_rank_zero_index:
                # Either Gloo, or the first task with MPI (which executes
                # orted, which runs mpirun_exec_fn for all tasks).
                service.wait_for_command_start(settings.start_timeout)
                service.wait_for_command_termination()
            else:
                # Other MPI tasks wait for the first task to finish.
                leader_addresses = driver.all_task_addresses(
                    local_rank_zero_index)
                leader = task_service.SparkTaskClient(
                    local_rank_zero_index, leader_addresses,
                    settings.key, settings.verbose)
                leader.wait_for_command_termination()

        return service.fn_result()
    finally:
        # Shutdown must not happen too quickly: task clients run a command
        # and then want to wait on the result. The task service was told not
        # to return from wait_for_command_termination too quickly, so clients
        # have had enough time to connect and shutting down here is safe.
        #
        # Shutdown has to block on running requests
        # (wait_for_command_exit_code) so they can finish serving the exit
        # code; it does block with
        # network.BasicService._server._block_on_close = True.
        service.shutdown()