Code Example #1
def _task_fn(index, driver_addresses, num_proc, tmout, key):
    task = task_service.TaskService(index, key)
    try:
        driver_client = driver_service.DriverClient(driver_addresses, key)
        driver_client.register_task(index, task.addresses(),
                                    host_hash.host_hash())
        task.wait_for_initial_registration(tmout)
        # Tasks ping each other in a circular fashion to determine interfaces reachable within
        # the cluster.
        next_task_index = (index + 1) % num_proc
        next_task_addresses = driver_client.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = task_service.TaskClient(next_task_index,
                                                   next_task_addresses,
                                                   key,
                                                   match_intf=True)
        driver_client.register_task_to_task_addresses(
            next_task_index, next_task_client.addresses())
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        if task_indices_on_this_host[0] == index:
            # The task with the first index will execute orted, which will run mpirun_exec_fn for all tasks.
            task.wait_for_command_start(tmout)
            task.wait_for_command_termination()
        else:
            # The remaining tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = task_service.TaskClient(
                task_indices_on_this_host[0], first_task_addresses, key)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        task.shutdown()
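
_task_fn above is the body of each Spark task: it registers with the driver, probes its neighbor's interfaces, and then either hosts the orted process (first task on the host) or waits for it to terminate. The sketch below shows, roughly, how such a function could be fanned out over Spark with one single-element partition per Horovod process; the _make_mapper and _run_spark_tasks helpers are illustrative assumptions, not the exact Horovod internals.

def _make_mapper(driver_addresses, num_proc, tmout, key):
    # Each Spark partition runs exactly one copy of _task_fn (defined above).
    def _mapper(index, _iterator):
        yield _task_fn(index, driver_addresses, num_proc, tmout, key)
    return _mapper


def _run_spark_tasks(spark_context, driver, num_proc, tmout, key):
    # num_proc single-element partitions -> num_proc concurrent Spark tasks.
    return (spark_context.range(num_proc, numSlices=num_proc)
            .mapPartitionsWithIndex(
                _make_mapper(driver.addresses(), num_proc, tmout, key))
            .collect())

In run() (Code Example #5), _make_spark_thread drives an equivalent collection on a background thread so the driver can continue with registration and the mpirun launch.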
Code Example #2
File: mpirun_rsh.py  Project: zxw866/horovod
def main(driver_addresses, host_hash, command):
    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    driver_client = driver_service.DriverClient(driver_addresses, key)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    # Since tasks with the same host hash have shared memory, we will run only
    # one ORTED process on the first task.
    first_task_index = task_indices[0]
    task_addresses = driver_client.all_task_addresses(first_task_index)
    task_client = task_service.TaskClient(first_task_index, task_addresses,
                                          key)
    task_client.run_command(command, os.environ)
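
main() above is only half of the rsh shim: mpirun invokes this script as the plm_rsh_agent with the base64-encoded driver addresses, the target host, and the orted command appended (see the mpirun command in Code Example #5). Below is a minimal sketch of that command-line entrypoint; the exact argument layout is an assumption, and codec and main refer to the names used in the excerpt above.

# Assumed argument layout: argv[1] = base64-encoded driver addresses,
# argv[2] = host hash, argv[3:] = the orted command to run on that host.
if __name__ == '__main__':
    import sys
    from shlex import quote

    if len(sys.argv) < 4:
        print('Usage: %s <driver addresses> <host hash> <command...>' % sys.argv[0])
        sys.exit(1)

    driver_addresses = codec.loads_base64(sys.argv[1])
    host_hash = sys.argv[2]
    # Re-quote the remaining arguments into a single shell command string
    # that main() forwards to the remote task service.
    command = ' '.join(quote(arg) for arg in sys.argv[3:])
    main(driver_addresses, host_hash, command)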
Code Example #3
File: mpirun_rsh.py  Project: PaulGureghian1/Horovod
def main(driver_addresses, host_hash, command):
    if ':' in host_hash:
        raise Exception(
            'Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    driver_client = driver_service.DriverClient(driver_addresses, key)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    # Since tasks with the same host hash have shared memory, we will run only
    # one ORTED process on the first task.
    first_task_index = task_indices[0]
    task_addresses = driver_client.all_task_addresses(first_task_index)
    task_client = task_service.TaskClient(first_task_index, task_addresses,
                                          key)
    task_client.run_command(command, os.environ)
Code Example #4
def main(driver_addresses):
    # Die if parent process terminates
    bg = threading.Thread(target=parent_process_monitor, args=(os.getppid(), ))
    bg.daemon = True
    bg.start()

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    driver_client = driver_service.DriverClient(driver_addresses, key)
    task_index = driver_client.task_index_by_rank(rank)
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.TaskClient(task_index, task_addresses, key)
    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)
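
The daemon thread above relies on a parent_process_monitor helper that is not shown in this excerpt. A minimal sketch of what such a monitor might look like, with the polling interval and exit behavior assumed:

import os
import time


def parent_process_monitor(initial_ppid):
    # Hypothetical sketch: poll the parent PID and exit hard once it changes,
    # i.e. the original parent died and this process was re-parented, so task
    # processes do not outlive mpirun/orted.
    while True:
        if os.getppid() != initial_ppid:
            os._exit(1)
        time.sleep(1)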
Code Example #5
def run(fn,
        args=(),
        kwargs={},
        num_proc=None,
        start_timeout=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register, and start running the code, in seconds.
                       If not set, falls back to the `HOROVOD_SPARK_START_TIMEOUT` environment variable;
                       if that is also unset, defaults to 600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.
    """
    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception(
            'Could not find an active SparkContext, are you running in a PySpark session?'
        )

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            print(
                'Running %d processes (inferred from spark.default.parallelism)...'
                % num_proc)
    else:
        if verbose >= 1:
            print('Running %d processes...' % num_proc)

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    result_queue = queue.Queue(1)
    tmout = timeout.Timeout(start_timeout)
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.DriverService(num_proc, fn, args, kwargs, key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, num_proc,
                                      driver, tmout, key, result_queue)
    try:
        driver.wait_for_initial_registration(tmout)
        if verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.TaskClient(index,
                                    driver.task_addresses_for_driver(index),
                                    key) for index in range(num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(tmout)
        if verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, num_proc):
            common_intfs.intersection_update(
                driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception(
                'Unable to find a set of common task-to-task communication interfaces: %s'
                % [(index, driver.task_addresses_for_tasks(index))
                   for index in range(num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel-shift until index 0 is on the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} '
            .format(
                num_proc=num_proc,
                hosts=','.join(
                    '%s:%d' % (host_hash,
                               len(driver.task_host_hash_indices()[host_hash]))
                    for host_hash in host_hashes),
                common_intfs=','.join(common_intfs),
                env=' '.join('-x %s' % key for key in env.keys()),
                python=sys.executable,
                encoded_driver_addresses=codec.dumps_base64(
                    driver.addresses())))
        if verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout,
                                            stderr)
        if exit_code != 0:
            raise Exception(
                'mpirun exited with code %d, see the error above.' % exit_code)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
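
A minimal usage sketch for run(), assuming an active PySpark session and that run() is exposed as horovod.spark.run; the training function body is a placeholder and the argument values are arbitrary.

import horovod.spark


def train(learning_rate):
    import horovod.tensorflow as hvd
    hvd.init()
    # ... build the model and train it, scaling learning_rate by hvd.size() ...
    return hvd.rank()


# Launches one copy of train() per Spark task and returns one result per rank.
results = horovod.spark.run(train, args=(0.01,), num_proc=4)
print(results)  # e.g. [0, 1, 2, 3]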