Example #1
    def test_settings_dump_drops_key(self):
        settings = hvd_settings.Settings(verbose=2, key="a secret key")
        clone = codec.loads_base64(codec.dumps_base64(settings))
        self.assertEqual(settings.verbose, clone.verbose)
        self.assertIsNotNone(settings.key)
        self.assertIsNone(clone.key)
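The test above exercises the round trip between codec.dumps_base64 and
codec.loads_base64 and verifies that the secret key is dropped during
serialization. As a rough mental model, the pair behaves like pickling an
object and base64-encoding the result; the sketch below is only illustrative
and is not Horovod's actual implementation.

import base64
import pickle  # the serializer actually used inside codec is an assumption here


def dumps_base64_sketch(obj):
    # Serialize the object and wrap the bytes in an ASCII-safe base64 string.
    return base64.b64encode(pickle.dumps(obj)).decode('utf-8')


def loads_base64_sketch(encoded):
    # Reverse the steps above: base64-decode, then unpickle.
    return pickle.loads(base64.b64decode(encoded))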
Example #2
def _save_meta_to_fs(fs, path, schema, rows, total_byte_size):
    with fs.open(path, 'wb') as train_meta_file:
        serialized_content = codec.dumps_base64(
            dict(schema=schema, rows=rows, total_byte_size=total_byte_size))
        train_meta_file.write(serialized_content.encode('utf-8'))
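A hypothetical counterpart for reading the metadata back; the function name is
illustrative (not necessarily present in Horovod), and codec is assumed to be
imported as in the examples above. It simply mirrors the encode path:

def _load_meta_from_fs(fs, path):
    # Read the base64 payload written by _save_meta_to_fs and restore the dict.
    with fs.open(path, 'rb') as train_meta_file:
        meta = codec.loads_base64(train_meta_file.read().decode('utf-8'))
    return meta['schema'], meta['rows'], meta['total_byte_size']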
Example #3
def _launch_task_servers(all_host_names, local_host_names, driver_addresses,
                         settings):
    """
    Executes the task server and service client task for registration on the
    hosts.
    :param all_host_names: list of addresses. For example,
        ['worker-0','worker-1']
        ['10.11.11.11', '10.11.11.12']
    :type all_host_names: list(string)
    :param local_host_names: names that resolve to one of the addresses of
    the local host's interfaces. For example,
        set(['localhost', '127.0.0.1'])
    :type local_host_names: set
    :param driver_addresses: map of interfaces and their address and port for
    the service. For example:
        {
            'lo': [('127.0.0.1', 34588)],
            'docker0': [('172.122.10.1', 34588)],
            'eth0': [('11.111.33.73', 34588)]
        }
    :type driver_addresses: map
    :param settings: the object that contains the settings for running horovod
    :type settings: horovod.run.common.util.settings.Settings
    :return:
    :rtype:
    """
    def _exec_command(command):
        host_output = six.StringIO()
        try:
            exit_code = safe_shell_exec.execute(command,
                                                stdout=host_output,
                                                stderr=host_output)
            if exit_code != 0:
                print('Launching horovodrun task function was not '
                      'successful:\n{host_output}'.format(
                          host_output=host_output.getvalue()))
                os._exit(exit_code)
        finally:
            host_output.close()
        return exit_code

    ssh_port_args = _get_ssh_port_args(all_host_names,
                                       ssh_port=settings.ssh_port,
                                       ssh_ports=settings.ssh_ports)

    args_list = []
    for index in range(len(all_host_names)):
        host_name = all_host_names[index]
        if host_name in local_host_names:
            command = \
                '{python} -m horovod.run.task_fn {index} ' \
                '{driver_addresses} {settings}'\
                .format(python=sys.executable,
                        index=codec.dumps_base64(index),
                        driver_addresses=codec.dumps_base64(driver_addresses),
                        settings=codec.dumps_base64(settings))
        else:
            command = \
                'ssh -o StrictHostKeyChecking=no {host} {ssh_port_arg} ' \
                '\'{python} -m horovod.run.task_fn {index} {driver_addresses}' \
                ' {settings}\''\
                .format(host=host_name,
                        ssh_port_arg=ssh_port_args[index],
                        python=sys.executable,
                        index=codec.dumps_base64(index),
                        driver_addresses=codec.dumps_base64(driver_addresses),
                        settings=codec.dumps_base64(settings))
        args_list.append([command])
    # Each thread will use the ssh command to launch the server on one task. If
    # an error occurs in one thread, the entire process will be terminated.
    # Otherwise, the threads keep running, and the ssh session -- and the task
    # server -- will be bound to the thread. If the horovodrun process dies,
    # all the ssh sessions and all the task servers will die as well.
    threads.execute_function_multithreaded(_exec_command,
                                           args_list,
                                           block_until_all_done=False)
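The launched module only has to undo the encoding applied when the command was
built. A hypothetical sketch of the receiving end is shown below; the argument
order mirrors the command string above, but the real horovod.run.task_fn entry
point may parse its arguments differently.

import sys

from horovod.run.common.util import codec

if __name__ == '__main__':
    index = codec.loads_base64(sys.argv[1])
    driver_addresses = codec.loads_base64(sys.argv[2])
    settings = codec.loads_base64(sys.argv[3])
    # ... start the task server and register it with the driver service ...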
Example #4
def run():
    args = parse_args()

    if args.version:
        print(horovod.__version__)
        exit(0)

    if args.host:
        all_host_names = [x for x in
                          [y.split(':')[0] for y in args.host.split(',')]]
    else:
        all_host_names = []

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by passing the
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.host:
            params += str(args.host) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Look up the default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '600'))
    tmout = timeout.Timeout(start_timeout)

    key = secret.make_secret_key()
    remote_host_names = []
    if args.host:
        if args.verbose:
            print("Filtering local host names.")
        remote_host_names = network.filter_local_addresses(all_host_names)

        if len(remote_host_names) > 0:
            if args.verbose:
                print("Checking ssh on all remote hosts.")
            # Check if we can ssh into all remote hosts successfully.
            _check_all_hosts_ssh_successful(remote_host_names, args.ssh_port,
                                            fn_cache=fn_cache)
            if args.verbose:
                print("SSH was successful into all the remote hosts.")

        hosts_arg = "-H {hosts}".format(hosts=args.host)
    else:
        # If the user does not specify any hosts, mpirun uses the local host by
        # default, so there is no need to specify localhost.
        hosts_arg = ""

    if args.host and len(remote_host_names) > 0:
        if args.verbose:
            print("Testing interfaces on all the hosts.")

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one interface;
        # otherwise, it will raise an exception.
        common_intfs = _driver_fn(key,
                                  all_host_names, local_host_names, tmout,
                                  args.ssh_port, args.verbose,
                                  fn_cache=fn_cache)

        tcp_intf_arg = "-mca btl_tcp_if_include {common_intfs}".format(
            common_intfs=','.join(common_intfs))
        nccl_socket_intf_arg = "-x NCCL_SOCKET_IFNAME={common_intfs}".format(
            common_intfs=','.join(common_intfs))

        if args.verbose:
            print("Interfaces on all the hosts were successfully checked.")
    else:
        # If all the given hosts are local, there is no need to specify the
        # interfaces because MPI does not use the network for local execution.
        tcp_intf_arg = ""
        nccl_socket_intf_arg = ""

    # Pass all the env variables to the mpirun command.
    env = os.environ.copy()

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

    if not _is_open_mpi_installed():
        raise Exception(
            'horovodrun convenience script currently only supports '
            'Open MPI.\n\n'
            'Choose one of:\n'
            '1. Install Open MPI 4.0.0+ and re-install Horovod '
            '(use --no-cache-dir pip option).\n'
            '2. Run distributed '
            'training script using the standard way provided by your'
            ' MPI distribution (usually mpirun, srun, or jsrun).')

    if args.ssh_port:
        ssh_port_arg = "-mca plm_rsh_args \"-p {ssh_port}\"".format(
            ssh_port=args.ssh_port)
    else:
        ssh_port_arg = ""

    mpirun_command = (
        'mpirun --allow-run-as-root --tag-output '
        '-np {num_proc} {hosts_arg} '
        '-bind-to none -map-by slot '
        '-mca pml ob1 -mca btl ^openib '
        '{ssh_port_arg} '
        '{tcp_intf_arg} '
        '-x NCCL_DEBUG=INFO '
        '{nccl_socket_intf_arg} '
        '{env} {command}'  # expect a lot of environment variables
            .format(num_proc=args.np,
                    hosts_arg=hosts_arg,
                    tcp_intf_arg=tcp_intf_arg,
                    nccl_socket_intf_arg=nccl_socket_intf_arg,
                    ssh_port_arg=ssh_port_arg,
                    env=' '.join('-x %s' % key for key in env.keys()),
                    command=' '.join(quote(par) for par in args.command))
    )

    if args.verbose:
        print(mpirun_command)
    # Execute the mpirun command.
    exit_code = safe_shell_exec.execute(mpirun_command, env)
    if exit_code != 0:
        raise Exception(
            'mpirun exited with code %d, see the error above.' % exit_code)
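The secret key reaches the workers only through the environment variable set
above. A minimal sketch of how the worker side could recover it, assuming the
same codec and secret modules; this is not necessarily how Horovod's services
read it internally.

import os

from horovod.run.common.util import codec, secret

key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
# The decoded key is then used to authenticate driver/task service messages.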
Example #5
    def train(serialized_model, train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader

        k = get_keras()
        k.backend.set_floatx(floatx)

        hvd = get_horovod()
        hvd.init()
        pin_gpu(hvd, tf, k)

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = calculate_shuffle_buffer_size(
                hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        # The model needs to be deserialized within the custom object scope.
        with k.utils.custom_object_scope(custom_objects):
            model = deserialize_keras_model(
                serialized_model, lambda x: hvd.load_model(x))

        # Horovod: adjust learning rate based on number of processes.
        k.backend.set_value(model.optimizer.lr,
                            k.backend.get_value(model.optimizer.lr) * hvd.size())

        # Verbose mode 1 will print a progress bar
        verbose = user_verbose if hvd.rank() == 0 else 0

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        with remote_store.get_local_output_dir() as run_output_dir:
            callbacks = [
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

                # Horovod: average metrics among workers at the end of every epoch.
                #
                # Note: This callback must be in the list before the ReduceLROnPlateau,
                # TensorBoard, or other metrics-based callbacks.
                hvd.callbacks.MetricAverageCallback(),
            ]
            callbacks += user_callbacks

            # Horovod: save checkpoints only on the first worker to prevent other workers from
            # corrupting them.
            if hvd.rank() == 0:
                ckpt_file = os.path.join(run_output_dir, remote_store.checkpoint_filename)
                logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)

                callbacks.append(k.callbacks.ModelCheckpoint(ckpt_file))
                if remote_store.saving_runs:
                    callbacks.append(k.callbacks.TensorBoard(logs_dir))
                    callbacks.append(SyncCallback(run_output_dir, remote_store.sync, k))

            if train_steps_per_epoch is None:
                steps_per_epoch = int(math.ceil(train_rows / batch_size / hvd.size()))
            else:
                steps_per_epoch = train_steps_per_epoch

            if validation_steps_per_epoch is None:
                # math.ceil because if val_rows is smaller than batch_size we still get at
                # least one step. float(val_rows) because under Python 2 val_rows / batch_size
                # would be integer division and evaluate to zero before math.ceil.
                validation_steps = int(math.ceil(float(val_rows) / batch_size / hvd.size())) \
                    if should_validate else None
            else:
                validation_steps = validation_steps_per_epoch

            schema_fields = feature_columns + label_columns
            if sample_weight_col:
                schema_fields.append(sample_weight_col)

            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, we default to
            # make_batch_reader unless there are data transformations.
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
                is_batch_reader = False
            else:
                reader_factory = make_batch_reader
                is_batch_reader = True

            # Petastorm: read data from the store with the correct shard for this rank.
            # Setting num_epochs=None creates an infinite iterator, which enables ranks
            # to perform training and validation with an unequal number of samples.
            with reader_factory(remote_store.train_data_path,
                                num_epochs=None,
                                cur_shard=hvd.rank(),
                                reader_pool_type='process',
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type='process',
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                    if should_validate else empty_batch_reader() as val_reader:

                    train_data = make_dataset(train_reader, shuffle_buffer_size,
                                              is_batch_reader, shuffle=True)
                    val_data = make_dataset(val_reader, shuffle_buffer_size,
                                            is_batch_reader, shuffle=False) \
                        if val_reader else None

                    history = fit(model, train_data, val_data, steps_per_epoch,
                                  validation_steps, callbacks, verbose)

            # Dataset API usage currently displays a wall of errors upon termination.
            # This global model registration ensures clean termination.
            # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
            globals()['_DATASET_FINALIZATION_HACK'] = model

            if hvd.rank() == 0:
                with open(ckpt_file, 'rb') as f:
                    return history.history, codec.dumps_base64(f.read()), hvd.size()
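On rank 0, train() returns the training history, the checkpoint file encoded
with codec.dumps_base64, and the world size. A hypothetical sketch of how a
caller could unpack that tuple; the variable and file names are illustrative.

history, serialized_ckpt, num_workers = train_result  # value returned by train() on rank 0
with open('restored_checkpoint.h5', 'wb') as f:
    # loads_base64 restores the raw checkpoint bytes read on rank 0.
    f.write(codec.loads_base64(serialized_ckpt))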
Example #6
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None, env=None, stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.
    """
    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if verbose >= 1:
            print('Running %d processes...' % num_proc)

    if start_timeout is None:
        # Look up the default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    result_queue = queue.Queue(1)
    tmout = timeout.Timeout(start_timeout)
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(num_proc, fn, args, kwargs, key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, num_proc, driver, tmout, key, result_queue)
    try:
        driver.wait_for_initial_registration(tmout)
        if verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [task_service.SparkTaskClient(index, driver.task_addresses_for_driver(index), key)
                        for index in range(num_proc)]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(tmout)
        if verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, num_proc):
            common_intfs.intersection_update(driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception('Unable to find a set of common task-to-task communication interfaces: %s'
                            % [(index, driver.task_addresses_for_tasks(index)) for index in range(num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} '
            .format(num_proc=num_proc,
                    hosts=','.join('%s:%d' % (host_hash, len(driver.task_host_hash_indices()[host_hash]))
                                   for host_hash in host_hashes),
                    common_intfs=','.join(common_intfs),
                    env=' '.join('-x %s' % key for key in env.keys() if env_util.is_exportable(key)),
                    python=sys.executable,
                    encoded_driver_addresses=codec.dumps_base64(driver.addresses())))
        if verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.' % exit_code)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
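The plm_rsh_agent string above makes mpirun call back into Python with the
base64-encoded driver addresses as the agent's first argument, followed by the
host and the remote command that mpirun appends. A hypothetical sketch of that
agent side; the real horovod.spark.driver.mpirun_rsh may parse its arguments
differently.

import sys

from horovod.run.common.util import codec

if __name__ == '__main__':
    driver_addresses = codec.loads_base64(sys.argv[1])
    host_hash = sys.argv[2]
    remote_command = ' '.join(sys.argv[3:])
    # ... ask the driver service to run remote_command on a task living on host_hash ...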
Example #7
def _serialize_keras_model(model, save_model_fn):
    """Serialize model into byte array encoded into base 64."""
    bio = io.BytesIO()
    with h5py.File(bio, 'w') as f:
        save_model_fn(model, f)
    return codec.dumps_base64(bio.getvalue())
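A hypothetical counterpart for the deserialization direction; Horovod's actual
deserialize_keras_model (used in Example #5) may have a different signature.
io, h5py, and codec are assumed to be imported as in the function above.

def _deserialize_keras_model_sketch(model_bytes_base64, load_model_fn):
    """Decode the base64 payload and load the model from an in-memory HDF5 file."""
    model_bytes = codec.loads_base64(model_bytes_base64)
    bio = io.BytesIO(model_bytes)
    with h5py.File(bio, 'r') as f:
        return load_model_fn(f)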
Example #8
def run(fn,
        args=(),
        kwargs={},
        num_proc=None,
        start_timeout=None,
        extra_mpi_args=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None,
        run_func=safe_shell_exec.execute):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.
        run_func: Run function to use. Must have arguments 'command', 'env', 'stdout', 'stderr'.
                  Defaults to safe_shell_exec.execute.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Look up the default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print(
                'Running %d processes (inferred from spark.default.parallelism)...'
                % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args,
                                               kwargs, settings.key,
                                               settings.nics)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings)
    try:
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(
                index, driver.task_addresses_for_driver(index), settings.key,
                settings.verbose) for index in range(settings.num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, settings.num_proc):
            common_intfs.intersection_update(
                driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception(
                'Unable to find a set of common task-to-task communication interfaces: %s'
                % [(index, driver.task_addresses_for_tasks(index))
                   for index in range(settings.num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join(
            '%s:%d' %
            (host_hash, len(driver.task_host_hash_indices()[host_hash]))
            for host_hash in host_hashes)

        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

        rsh_agent = (sys.executable, '-m', 'horovod.spark.driver.mpirun_rsh',
                     codec.dumps_base64(driver.addresses()),
                     codec.dumps_base64(settings))
        settings.extra_mpi_args = (
            '{extra_mpi_args} -x NCCL_DEBUG=INFO -mca plm_rsh_agent "{rsh_agent}"'
            .format(extra_mpi_args=settings.extra_mpi_args
                    if settings.extra_mpi_args else '',
                    rsh_agent=' '.join(rsh_agent)))
        command = (sys.executable, '-m', 'horovod.spark.task.mpirun_exec_fn',
                   codec.dumps_base64(driver.addresses()),
                   codec.dumps_base64(settings))
        mpi_run(settings,
                common_intfs,
                env,
                command,
                stdout=stdout,
                stderr=stderr,
                run_func=run_func)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
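On the task side, horovod.spark.task.mpirun_exec_fn receives the two base64
arguments built into the command above. A hypothetical sketch of decoding them;
note that, per Example #1, the Settings secret key does not survive
serialization and is delivered via HOROVOD_SECRET_KEY instead.

import sys

from horovod.run.common.util import codec

if __name__ == '__main__':
    driver_addresses = codec.loads_base64(sys.argv[1])
    settings = codec.loads_base64(sys.argv[2])
    # ... connect back to the driver, look up this rank's function, and execute it ...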