Example #1
    def test_mpi_run_full(self):
        if not mpi_available():
            self.skipTest("MPI is not available")

        cmd = ['cmd', 'arg1', 'arg2']
        nics = ['eth0', 'eth1']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout = '<stdout>'
        stderr = '<stderr>'
        tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            start_timeout=tmout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)

        def mpi_impl_flags(tcp, env=None):
            return ["--mock-mpi-impl-flags"], []

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=mpi_impl_flags) as impl:
            with mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                            return_value=0) as execute:
                mpi_run(settings, nics, env, cmd, stdout=stdout, stderr=stderr)

                # assert call on _get_mpi_implementation_flags
                impl.assert_called_once_with(None, env=env)

                # call the mocked _get_mpi_implementation_flags method ourselves
                mpi_flags, _ = horovod.run.mpi_run._get_mpi_implementation_flags(
                    False)
                self.assertIsNotNone(mpi_flags)
                expected_command = (
                    'mpirun '
                    '--allow-run-as-root --tag-output '
                    '-np 1 -H >host names go here< '
                    '>binding args go here< '
                    '{mpi_flags} '
                    '-mca plm_rsh_args "-p 1022" '
                    '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
                    '--output-filename >output filename goes here< '
                    '-x env1 -x env2 '
                    '>mpi-extra args go here< '
                    'cmd arg1 arg2').format(mpi_flags=' '.join(mpi_flags))
                expected_env = {
                    'env1': 'val1',
                    'env2': 'val2',
                    'PATH': os.environ.get('PATH')
                }
                execute.assert_called_once_with(expected_command,
                                                env=expected_env,
                                                stdout=stdout,
                                                stderr=stderr)
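
Every snippet on this page exercises the same small contract around timeout.Timeout: construct it with a number of seconds and a message that may contain an '{activity}' placeholder, poll remaining() for the seconds left, and call check_time_out_for(activity) to raise once the deadline has passed (see examples #4 and #7). The sketch below is not the Horovod implementation, only a stand-alone stand-in for that contract; the class name SketchTimeout is made up for illustration.

import time


class SketchTimeout(object):
    """Stand-in for the Timeout contract used throughout these examples."""

    def __init__(self, timeout_sec, message='Timed out waiting for {activity}.'):
        self._deadline = time.time() + timeout_sec
        self._message = message

    def remaining(self):
        # Seconds left until the deadline, never negative.
        return max(0.0, self._deadline - time.time())

    def check_time_out_for(self, activity):
        # Raise with the activity substituted into the message once expired.
        if self.remaining() == 0.0:
            raise Exception(self._message.format(activity=activity))


# Usage mirroring examples #4 and #7:
tmout = SketchTimeout(5, message='Timed out waiting for {activity}.')
tmout.check_time_out_for('task registration')  # no-op while time remains
time.sleep(tmout.remaining())                  # block out the rest of the window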
Example #2
    def _run_command(self, command, env, event):
        super(SparkTaskService, self)._run_command(command, env, event)

        if self._minimum_command_lifetime_s is not None:
            self._minimum_command_lifetime = timeout.Timeout(
                self._minimum_command_lifetime_s,
                message='Just measuring runtime')
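
Example #2 only records the minimum lifetime; example #7 below shows the consuming side, where the task service sleeps out whatever is left of that window before shutting down. A hedged one-function sketch of that consuming step (the function name is illustrative):

import time


def wait_out_minimum_lifetime(minimum_command_lifetime):
    # Sleep until the minimum-lifetime Timeout recorded at command start has
    # fully elapsed, so the service does not shut down too quickly.
    if minimum_command_lifetime is not None:
        time.sleep(minimum_command_lifetime.remaining())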
Example #3
    def __init__(self, ray_ctx, verbose=None, start_timeout=None):

        self.cores_per_node = ray_ctx.ray_node_cpu_cores
        self.num_nodes = ray_ctx.num_ray_nodes
        self.worker_class = make_horovod_worker(self.cores_per_node)
        self.remote_workers = [self.worker_class.remote() for _ in range(self.num_nodes)]

        hosts = ray.get([worker.hostname.remote() for worker in self.remote_workers])
        hosts_spec, name_rank_to_id, host_to_size = _hosts_to_hosts_spec(hosts)
        self.host_alloc_plan = _allocate(",".join(hosts_spec), self.num_nodes)
        global_rendezv = RendezvousServer(True)
        global_rendezv_port = global_rendezv.start_server(self.host_alloc_plan)

        if start_timeout is None:
            start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

        tmout = timeout.Timeout(start_timeout,
                                message='Timed out waiting for {activity}. Please '
                                        'check connectivity between servers. You '
                                        'may need to increase the --start-timeout '
                                        'parameter if you have too many servers.')

        all_host_names = list(host_to_size)

        settings = hvd_settings.Settings(verbose=2 if verbose else 0,
                                         key=secret.make_secret_key(),
                                         timeout=tmout,
                                         num_hosts=len(all_host_names),
                                         num_proc=self.num_nodes,
                                         hosts=",".join(hosts_spec))

        common_intfs = _find_common_network_interface(host_to_size, name_rank_to_id,
                                                      self.remote_workers, settings)
        iface = list(common_intfs)[0]
        driver_ip = _get_driver_ip([iface])

        common_envs = {
            "HOROVOD_GLOO_RENDEZVOUS_ADDR": driver_ip,
            "HOROVOD_GLOO_RENDEZVOUS_PORT": str(global_rendezv_port),
            "HOROVOD_CONTROLLER": "gloo",
            "HOROVOD_CPU_OPERATIONS": "gloo",
            "HOROVOD_GLOO_IFACE": iface,
            "PYTHONUNBUFFERED": '1',
        }

        for key in os.environ:
            if key.startswith("HOROVOD"):
                common_envs[key] = os.environ[key]

        # todo support other Horovod envs
        self.per_worker_envs = [common_envs.copy() for _ in range(self.num_nodes)]
        for alloc_info in self.host_alloc_plan:
            key = (alloc_info.hostname, alloc_info.local_rank)
            local_envs = self.per_worker_envs[name_rank_to_id[key]]
            local_envs["HOROVOD_RANK"] = str(alloc_info.rank)
            local_envs["HOROVOD_SIZE"] = str(alloc_info.size)
            local_envs["HOROVOD_LOCAL_RANK"] = str(alloc_info.local_rank)
            local_envs["HOROVOD_LOCAL_SIZE"] = str(alloc_info.local_size)
            local_envs["HOROVOD_CROSS_RANK"] = str(alloc_info.cross_rank)
            local_envs["HOROVOD_CROSS_SIZE"] = str(alloc_info.cross_size)
Example #4
    def wait_for_available_slots(self, min_np, min_hosts=1):
        extra_message = ' An elastic job also requires that at least two hosts ' \
                        'are available to resolve compatible network interfaces. If you know which interfaces ' \
                        'are compatible in your network, set `--nic` to skip this check.' if min_hosts > 1 else ''

        tmout = timeout.Timeout(
            self._timeout,
            message=
            'Timed out waiting for {{activity}}. Please check that you have '
            'enough resources to run at least {min_np} Horovod processes.{extra_message}'
            .format(min_np=min_np, extra_message=extra_message))

        self._wait_hosts_cond.acquire()
        try:
            while True:
                current_hosts = self._host_manager.current_hosts
                if current_hosts.count_available_slots() >= min_np and len(
                        current_hosts.available_hosts) >= min_hosts:
                    return current_hosts
                if self._shutdown.is_set():
                    raise RuntimeError(
                        'Job has been shutdown, see above error messages for details.'
                    )
                self._wait_hosts_cond.wait(tmout.remaining())
                tmout.check_time_out_for(
                    'minimum number of slots to become available')
        finally:
            self._wait_hosts_cond.release()
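
The loop in example #4 is a general pattern: re-check a predicate under a condition variable, wait at most remaining() seconds, and let check_time_out_for() raise once the deadline passes. A self-contained sketch of that pattern with the standard library (the predicate and function name are made up, and tmout is any object honouring the Timeout contract sketched after example #1):

import threading


def wait_for(predicate, cond, tmout, activity='condition'):
    # Generic wait-with-deadline loop: wake up on notify or timeout slice,
    # re-check the predicate, and raise once the overall deadline expires.
    with cond:
        while True:
            if predicate():
                return
            cond.wait(tmout.remaining())
            tmout.check_time_out_for(activity)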
Example #5
File: runner.py  Project: zpcalan/horovod
def _run_elastic(args):
    # construct host discovery component
    if args.host_discovery_script:
        discover_hosts = discovery.HostDiscoveryScript(
            args.host_discovery_script, args.slots)
    elif args.hosts:
        _, available_host_slots = hosts.parse_hosts_and_slots(args.hosts)
        if len(available_host_slots) < 2:
            raise ValueError(
                'Cannot run in fault tolerance mode with fewer than 2 hosts.')
        discover_hosts = discovery.FixedHosts(available_host_slots)
    else:
        raise ValueError(
            'One of --host-discovery-script, --hosts, or --hostnames must be provided'
        )

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = elastic_settings.ElasticSettings(
        discovery=discover_hosts,
        min_np=args.min_np or args.np,
        max_np=args.max_np,
        elastic_timeout=args.elastic_timeout,
        reset_limit=args.reset_limit,
        num_proc=args.np,
        verbose=2 if args.verbose else 0,
        ssh_port=args.ssh_port,
        extra_mpi_args=args.mpi_args,
        key=secret.make_secret_key(),
        start_timeout=tmout,
        output_filename=args.output_filename,
        run_func_mode=args.run_func is not None,
        nics=args.nics)

    if not gloo_built(verbose=(settings.verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)
    gloo_run_elastic(settings, env, args.command)
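
For --host-discovery-script in the elastic path above, the script is expected to print the currently available hosts, one per line, optionally with a slot count as hostname:slots; check the Horovod elastic documentation for the exact contract. A minimal illustrative script (hostnames are placeholders):

#!/usr/bin/env python
# Minimal host discovery script sketch: print one available host per line,
# optionally suffixed with ':slots'. Hostnames below are placeholders.
print('worker-0:2')
print('worker-1:2')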
Example #6
    def test_mpi_run_full(self):
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd', 'arg1', 'arg2']
        common_intfs = ['eth0', 'eth1']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout = '<stdout>'
        stderr = '<stderr>'
        tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            timeout=tmout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)
        run_func = MagicMock(return_value=0)

        mpi_run(settings,
                common_intfs,
                env,
                cmd,
                stdout=stdout,
                stderr=stderr,
                run_func=run_func)

        mpi_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        expected_command = (
            'mpirun '
            '--allow-run-as-root --tag-output '
            '-np 1 -H >host names go here< '
            '>binding args go here< '
            '{mpi_flags} '
            '-mca plm_rsh_args "-p 1022" '
            '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
            '--output-filename >output filename goes here< '
            '-x env1 -x env2 '
            '>mpi-extra args go here< '
            'cmd arg1 arg2').format(mpi_flags=' '.join(mpi_flags))
        expected_env = {'env1': 'val1', 'env2': 'val2'}
        run_func.assert_called_once_with(command=expected_command,
                                         env=expected_env,
                                         stdout=stdout,
                                         stderr=stderr)
Example #7
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    # deserialized on Spark workers, settings do not contain the key, so it is given here explicitly
    # Spark RPC communicates the key and supports encryption
    # for convenience, we put it back into settings
    settings.key = key

    task = task_service.SparkTaskService(index, settings.key, settings.nics,
                                         settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(),
                                    host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())

        # With Gloo all tasks wait for the command
        # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks.
        minimum_lifetime_after_start = None
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            minimum_lifetime_after_start = timeout.Timeout(
                MINIMUM_COMMAND_LIFETIME_S, message='Just measuring runtime')
            task.wait_for_command_termination()
        else:
            # The rest of tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        # command terminated, make sure this task service does not shutdown too quickly after
        # the client started the command as it needs some time to connect again
        # to wait for the result after starting the command (see horovod.spark.driver.rsh).
        if minimum_lifetime_after_start is not None:
            time.sleep(minimum_lifetime_after_start.remaining())

        return task.fn_result()
    finally:
        # this has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code
        # shutdown does block with network.BasicService._server._block_on_close = True
        task.shutdown()
Example #8
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None, env=None,
        stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to the `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If that is also not set, defaults to 600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     key=secret.make_secret_key(),
                                     timeout=tmout)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs,
                                               settings.key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings)
    try:
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(index,
                                         driver.task_addresses_for_driver(index),
                                         settings.key, settings.verbose)
            for index in range(settings.num_proc)]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, settings.num_proc):
            common_intfs.intersection_update(driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception('Unable to find a set of common task-to-task communication interfaces: %s'
                            % [(index, driver.task_addresses_for_tasks(index)) for index in range(settings.num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses} {settings}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} {settings}'
                .format(num_proc=settings.num_proc,
                        hosts=','.join('%s:%d' % (host_hash, len(driver.task_host_hash_indices()[host_hash]))
                                       for host_hash in host_hashes),
                        common_intfs=','.join(common_intfs),
                        env=' '.join('-x %s' % key for key in env.keys() if env_util.is_exportable(key)),
                        python=sys.executable,
                        encoded_driver_addresses=codec.dumps_base64(driver.addresses()),
                        settings=codec.dumps_base64(settings)))
        if settings.verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.' % exit_code)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
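
For completeness, a hedged usage sketch of the run() API documented above; the training function and the horovod.torch import are illustrative, and any framework supported by Horovod follows the same shape.

import horovod.spark


def train(lr):
    # Runs once per Horovod process inside a Spark task.
    import horovod.torch as hvd
    hvd.init()
    return hvd.rank()


# Returns one result per rank, ordered by rank.
results = horovod.spark.run(train, args=(0.01,), num_proc=2, verbose=2)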
Example #9
def run(fn,
        args=(),
        kwargs={},
        num_proc=None,
        start_timeout=None,
        use_mpi=None,
        use_gloo=None,
        extra_mpi_args=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to the `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If that is also not set, defaults to 600 seconds.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print(
                'Running %d processes (inferred from spark.default.parallelism)...'
                % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args,
                                               kwargs, settings.key,
                                               settings.nics)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, use_gloo)
    try:
        # wait for all tasks to register and notify them
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(
                index, driver.task_addresses_for_driver(index), settings.key,
                settings.verbose) for index in range(settings.num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join(
            '%s:%d' %
            (host_hash, len(driver.task_host_hash_indices()[host_hash]))
            for host_hash in host_hashes)

        # Determine the ranks-to-indices mapping
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
Example #10
File: run.py  Project: brainhart/horovod
def _run(args):
    if args.check_build:
        check_build(args.verbose)

    # if hosts are not specified, either parse from hostfile, or default as
    # localhost
    if not args.hosts:
        if args.hostfile:
            args.hosts = parse_host_files(args.hostfile)
        else:
            # Set hosts to localhost if not specified
            args.hosts = 'localhost:{np}'.format(np=args.np)

    host_list = args.hosts.split(',')
    all_host_names = []
    pattern = re.compile(r'^[\w.-]+:\d+$')
    for host in host_list:
        if not pattern.match(host.strip()):
            raise ValueError('Invalid host input, please make sure it is in '
                             'the form: worker-0:2,worker-1:2.')
        all_host_names.append(host.strip().split(':')[0])

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     ssh_ports=args.ssh_ports,
                                     extra_mpi_args=args.mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nic=args.nic)

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        elif args.ssh_ports:
            params += str(args.ssh_ports)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        if args.ssh_ports:
            ssh_ports = [
                port for host, port in zip(all_host_names,
                                           args.ssh_ports.split(","))
                if host in set(remote_host_names)
            ]
            ssh_ports = ",".join(ssh_ports)
        else:
            ssh_ports = None
        _check_all_hosts_ssh_successful(remote_host_names,
                                        ssh_port=args.ssh_port,
                                        ssh_ports=ssh_ports,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Testing interfaces on all the hosts.')

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one interface
        # otherwise, it will raise an exception.
        common_intfs = _driver_fn(all_host_names,
                                  local_host_names,
                                  settings,
                                  fn_cache=fn_cache)

        if settings.verbose >= 2:
            print('Interfaces on all the hosts were successfully checked.')
            print('Common interface found: ' + ' '.join(common_intfs))

    else:
        if settings.verbose >= 2:
            print('All hosts are local, finding the interfaces '
                  'with address 127.0.0.1')
        # If all the given hosts are local, find the interfaces with address
        # 127.0.0.1
        common_intfs = set()
        for iface, addrs in net_if_addrs().items():
            if settings.nic and iface != settings.nic:
                continue
            for addr in addrs:
                if addr.family == AF_INET and addr.address == '127.0.0.1':
                    common_intfs.add(iface)
                    break

        if len(common_intfs) == 0:
            raise ValueError('No interface is found for address 127.0.0.1.')

        if settings.verbose >= 2:
            print('Local interface found ' + ' '.join(common_intfs))

    # get the driver IPv4 address
    driver_ip = _get_driver_ip(common_intfs)

    if args.run_func:
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        pickled_exec_func = cloudpickle.dumps(args.run_func)
        put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                              'func', pickled_exec_func)

        command = [
            sys.executable, '-m', 'horovod.run.run_task',
            str(driver_ip),
            str(run_func_server_port)
        ]

        try:
            _launch_job(args, remote_host_names, settings, common_intfs,
                        command)
            results = [None] * args.np
            # TODO: make it parallel to improve performance
            for i in range(args.np):
                pickled_result = read_data_from_kvstore(
                    driver_ip, run_func_server_port, 'runfunc_result', str(i))
                results[i] = cloudpickle.loads(pickled_result)
            return results
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, remote_host_names, settings, common_intfs, command)
        return None
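
The hosts parsing above is easy to lift out and test on its own; a small sketch re-using the same validation regex (the function name is illustrative):

import re


def parse_host_names(hosts):
    # Validate a comma-separated 'hostname:slots' string and return the host names.
    pattern = re.compile(r'^[\w.-]+:\d+$')
    host_names = []
    for host in hosts.split(','):
        host = host.strip()
        if not pattern.match(host):
            raise ValueError('Invalid host input, please make sure it is in '
                             'the form: worker-0:2,worker-1:2.')
        host_names.append(host.split(':')[0])
    return host_names


assert parse_host_names('worker-0:2,worker-1:2') == ['worker-0', 'worker-1']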
Example #11
def _run_static(args):
    all_host_names, _ = parse_hosts_and_slots(args.hosts)

    nics_set = set(args.nics.split(',')) if args.nics else None

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     extra_mpi_args=args.mpi_args,
                                     tcp_flag=args.tcp_flag,
                                     binding_args=args.binding_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     num_hosts=len(all_host_names),
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nics=nics_set)

    # This cache stores the results of checks performed by horovod
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        _check_all_hosts_ssh_successful(remote_host_names,
                                        args.ssh_port,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    nics = driver_service.get_common_interfaces(settings, all_host_names,
                                                remote_host_names, fn_cache)

    if args.run_func:
        # get the driver IPv4 address
        driver_ip = network.get_driver_ip(nics)
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                              'func', args.run_func)

        command = [
            sys.executable, '-m', 'horovod.run.run_task',
            str(driver_ip),
            str(run_func_server_port)
        ]

        try:
            _launch_job(args, settings, nics, command)
            results = [None] * args.np
            # TODO: make it parallel to improve performance
            for i in range(args.np):
                results[i] = read_data_from_kvstore(driver_ip,
                                                    run_func_server_port,
                                                    'runfunc_result', str(i))
            return results
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, settings, nics, command)
        return None
Example #12
def run():
    args = parse_args()

    if args.version:
        print(horovod.__version__)
        exit(0)

    if args.host:
        all_host_names = [y.split(':')[0] for y in args.host.split(',')]
    else:
        all_host_names = []

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.host:
            params += str(args.host) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '600'))
    tmout = timeout.Timeout(start_timeout)

    key = secret.make_secret_key()
    remote_host_names = []
    if args.host:
        if args.verbose:
            print("Filtering local host names.")
        remote_host_names = network.filter_local_addresses(all_host_names)

        if len(remote_host_names) > 0:
            if args.verbose:
                print("Checking ssh on all remote hosts.")
            # Check if we can ssh into all remote hosts successfully.
            _check_all_hosts_ssh_successful(remote_host_names, args.ssh_port,
                                            fn_cache=fn_cache)
            if args.verbose:
                print("SSH was successful into all the remote hosts.")

        hosts_arg = "-H {hosts}".format(hosts=args.host)
    else:
        # if user does not specify any hosts, mpirun by default uses local host.
        # There is no need to specify localhost.
        hosts_arg = ""

    if args.host and len(remote_host_names) > 0:
        if args.verbose:
            print("Testing interfaces on all the hosts.")

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one interface
        # otherwise, it will raise an exception.
        common_intfs = _driver_fn(key,
                                  all_host_names, local_host_names, tmout,
                                  args.ssh_port, args.verbose,
                                  fn_cache=fn_cache)

        tcp_intf_arg = "-mca btl_tcp_if_include {common_intfs}".format(
            common_intfs=','.join(common_intfs))
        nccl_socket_intf_arg = "-x NCCL_SOCKET_IFNAME={common_intfs}".format(
            common_intfs=','.join(common_intfs))

        if args.verbose:
            print("Interfaces on all the hosts were successfully checked.")
    else:
        # If all the given hosts are local, no need to specify the interfaces
        # because MPI does not use network for local execution.
        tcp_intf_arg = ""
        nccl_socket_intf_arg = ""

    # Pass all the env variables to the mpirun command.
    env = os.environ.copy()

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

    if not _is_open_mpi_installed():
        raise Exception(
            'horovodrun convenience script currently only supports '
            'Open MPI.\n\n'
            'Choose one of:\n'
            '1. Install Open MPI 4.0.0+ and re-install Horovod '
            '(use --no-cache-dir pip option).\n'
            '2. Run distributed '
            'training script using the standard way provided by your'
            ' MPI distribution (usually mpirun, srun, or jsrun).')

    if args.ssh_port:
        ssh_port_arg = "-mca plm_rsh_args \"-p {ssh_port}\"".format(
            ssh_port=args.ssh_port)
    else:
        ssh_port_arg = ""

    mpirun_command = (
        'mpirun --allow-run-as-root --tag-output '
        '-np {num_proc} {hosts_arg} '
        '-bind-to none -map-by slot '
        '-mca pml ob1 -mca btl ^openib '
        '{ssh_port_arg} '
        '{tcp_intf_arg} '
        '-x NCCL_DEBUG=INFO '
        '{nccl_socket_intf_arg} '
        '{env} {command}'  # expect a lot of environment variables
            .format(num_proc=args.np,
                    hosts_arg=hosts_arg,
                    tcp_intf_arg=tcp_intf_arg,
                    nccl_socket_intf_arg=nccl_socket_intf_arg,
                    ssh_port_arg=ssh_port_arg,
                    env=' '.join('-x %s' % key for key in env.keys()),
                    command=' '.join(quote(par) for par in args.command))
    )

    if args.verbose:
        print(mpirun_command)
    # Execute the mpirun command.
    exit_code = safe_shell_exec.execute(mpirun_command, env)
    if exit_code != 0:
        raise Exception(
            'mpirun exited with code %d, see the error above.' % exit_code)
Example #13
def run():
    args = parse_args()

    if args.check_build:
        check_build(args.verbose)

    # if hosts are not specified, either parse from hostfile, or default as
    # localhost
    if not args.hosts:
        if args.hostfile:
            args.hosts = parse_host_files(args.hostfile)
        else:
            # Set hosts to localhost if not specified
            args.hosts = 'localhost:{np}'.format(np=args.np)

    host_list = args.hosts.split(',')
    all_host_names = []
    pattern = re.compile(r'^[\w.-]+:\d+$')
    for host in host_list:
        if not pattern.match(host.strip()):
            raise ValueError('Invalid host input, please make sure it is in '
                             'the form: worker-0:2,worker-1:2.')
        all_host_names.append(host.strip().split(':')[0])

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     command=args.command)

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        _check_all_hosts_ssh_successful(remote_host_names,
                                        args.ssh_port,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Testing interfaces on all the hosts.')

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one interface
        # otherwise, it will raise an exception.
        common_intfs = _driver_fn(all_host_names,
                                  local_host_names,
                                  settings,
                                  fn_cache=fn_cache)

        if settings.verbose >= 2:
            print('Interfaces on all the hosts were successfully checked.')
            print('Common interface found: ' + ' '.join(common_intfs))

    else:
        if settings.verbose >= 2:
            print('All hosts are local, finding the interfaces '
                  'with address 127.0.0.1')
        # If all the given hosts are local, find the interfaces with address
        # 127.0.0.1
        common_intfs = set()
        for iface, addrs in net_if_addrs().items():
            for addr in addrs:
                if addr.family == AF_INET and addr.address == '127.0.0.1':
                    common_intfs.add(iface)
                    break

        if len(common_intfs) == 0:
            raise ValueError('No interface is found for address 127.0.0.1.')

        if settings.verbose >= 2:
            print('Local interface found ' + ' '.join(common_intfs))

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)

    if args.use_gloo:
        if not gloo_built(verbose=(settings.verbose >= 2)):
            raise ValueError(
                'Gloo support has not been built.  If this is not expected, ensure CMake is installed '
                'and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
            )
        gloo_run(settings, remote_host_names, common_intfs, env)
    elif args.use_mpi:
        if not mpi_built(verbose=(settings.verbose >= 2)):
            raise ValueError(
                'MPI support has not been built.  If this is not expected, ensure MPI is installed '
                'and reinstall Horovod with HOROVOD_WITH_MPI=1 to debug the build error.'
            )
        mpi_run(settings, common_intfs, env)
    else:
        if mpi_built(verbose=(settings.verbose >= 2)):
            mpi_run(settings, common_intfs, env)
        elif gloo_built(verbose=(settings.verbose >= 2)):
            gloo_run(settings, remote_host_names, common_intfs, env)
        else:
            raise ValueError(
                'Neither MPI nor Gloo support has been built. Try reinstalling Horovod ensuring that '
                'either MPI is installed (MPI) or CMake is installed (Gloo).')
Example #14
def run_elastic(fn,
                args=(),
                kwargs={},
                num_proc=None,
                min_np=None,
                max_np=None,
                start_timeout=None,
                elastic_timeout=None,
                reset_limit=None,
                env=None,
                verbose=1,
                nics=None):
    """
    Runs Elastic Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to the `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If that is also not set, defaults to 600 seconds.
        elastic_timeout: Timeout for elastic initialisation after re-scaling the cluster.
                       If not set, falls back to the `HOROVOD_ELASTIC_TIMEOUT` environment variable value.
                       If that is also not set, defaults to 600 seconds.
        reset_limit: Maximum number of resets after which the job is terminated.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if not gloo_built(verbose=(verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    if num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.initialExecutors
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            logging.info(
                'Running %d processes (inferred from spark.default.parallelism)...',
                num_proc)
    else:
        if verbose >= 1:
            logging.info('Running %d processes...', num_proc)

    if min_np is None:
        # TODO: #2023 try spark.dynamicAllocation.minExecutors
        min_np = num_proc
    if max_np is None:
        # TODO: #2023 try spark.dynamicAllocation.maxExecutors
        max_np = num_proc

    # start Spark driver service and launch settings.num_proc Spark tasks
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(num_proc, fn, args, kwargs, key,
                                               nics)

    discovery = host_discovery.SparkDriverHostDiscovery(driver)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_elastic_settings.ElasticSettings(
        discovery=discovery,
        min_np=min_np,
        max_np=max_np,
        elastic_timeout=elastic_timeout,
        reset_limit=reset_limit,
        num_proc=num_proc,
        verbose=verbose,
        key=key,
        start_timeout=tmout,
        nics=nics,
        run_func_mode=True)

    result_queue = queue.Queue(1)

    # launch settings.num_proc / settings.max_np Spark tasks
    spark_thread = _make_spark_thread(spark_context,
                                      spark_job_group,
                                      driver,
                                      result_queue,
                                      settings,
                                      use_gloo=True,
                                      is_elastic=True)
    try:
        # Register task addresses of initial num_proc tasks
        _register_task_addresses(driver, settings)

        # Run the job
        gloo_run_elastic(settings, driver, env)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]
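
A hedged usage sketch for run_elastic() as documented above; the training function is illustrative, and in a real elastic job it would typically wrap the training loop with hvd.elastic.run and a state object so workers can join and leave.

import horovod.spark


def train():
    # Runs once per Horovod process; elastic jobs normally wrap the training
    # loop with hvd.elastic.run and a state object (omitted in this sketch).
    import horovod.torch as hvd
    hvd.init()
    return hvd.rank()


results = horovod.spark.run_elastic(train, num_proc=2, min_np=2, max_np=4)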
Example #15
File: run.py  Project: BinLiu2015/horovod
def _run(args):
    if args.check_build:
        check_build(args.verbose)

    # If LSF is used, use default values from job config
    if lsf.LSFUtils.using_lsf():
        if not args.np:
            args.np = lsf.LSFUtils.get_num_processes()
        if not args.hosts and not args.hostfile:
            args.hosts = ','.join(
                '{host}:{np}'.format(host=host, np=lsf.LSFUtils.get_num_gpus())
                for host in lsf.LSFUtils.get_compute_hosts())

    # if hosts are not specified, either parse from hostfile, or default as
    # localhost
    if not args.hosts:
        if args.hostfile:
            args.hosts = parse_host_files(args.hostfile)
        else:
            # Set hosts to localhost if not specified
            args.hosts = 'localhost:{np}'.format(np=args.np)

    host_list = args.hosts.split(',')
    all_host_names = []
    pattern = re.compile(r'^[\w.-]+:\d+$')
    for host in host_list:
        if not pattern.match(host.strip()):
            raise ValueError('Invalid host input, please make sure it is in '
                             'the form: worker-0:2,worker-1:2.')
        all_host_names.append(host.strip().split(':')[0])

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     extra_mpi_args=args.mpi_args,
                                     tcp_flag=args.tcp_flag,
                                     binding_args=args.binding_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nic=args.nic)

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        _check_all_hosts_ssh_successful(remote_host_names,
                                        args.ssh_port,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    common_intfs = driver_service._get_common_interfaces(
        settings, all_host_names, remote_host_names, fn_cache)
    if args.run_func:
        # get the driver IPv4 address
        driver_ip = network._get_driver_ip(common_intfs)
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        pickled_exec_func = cloudpickle.dumps(args.run_func)
        put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                              'func', pickled_exec_func)

        command = [
            sys.executable, '-m', 'horovod.run.run_task',
            str(driver_ip),
            str(run_func_server_port)
        ]

        try:
            _launch_job(args, remote_host_names, settings, common_intfs,
                        command)
            results = [None] * args.np
            # TODO: make it parallel to improve performance
            for i in range(args.np):
                pickled_result = read_data_from_kvstore(
                    driver_ip, run_func_server_port, 'runfunc_result', str(i))
                results[i] = cloudpickle.loads(pickled_result)
            return results
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, remote_host_names, settings, common_intfs, command)
        return None