def get_allocation_info():
    """Returns and sets the static CSM allocation info.

    Lazily populates ``LSFUtils._csm_allocation_info`` by querying the CSM
    allocation tool and the CSM node tool, then returns the cached dict on
    every subsequent call.

    Raises:
        RuntimeError: if either CSM query command exits with non-zero status.
    """
    if not LSFUtils._csm_allocation_info:
        lsf_allocation_id = os.environ["CSM_ALLOCATION_ID"].strip()
        output = io.StringIO()
        exit_code = safe_shell_exec.execute("{cmd} -a {allocation}".format(
            cmd=LSFUtils._CSM_ALLOCATION_QUERY,
            allocation=lsf_allocation_id), stdout=output, stderr=output)
        if exit_code != 0:
            raise RuntimeError(
                "{cmd} failed with exit code {exit_code}".format(
                    cmd=LSFUtils._CSM_ALLOCATION_QUERY,
                    exit_code=exit_code))
        # The CSM tool emits YAML; the parsed dict becomes the cache.
        LSFUtils._csm_allocation_info = yaml.safe_load(output.getvalue())

        # Fetch the total number of cores and gpus for the first host
        # (NOTE(review): this presumes all hosts in the allocation are
        # homogeneous — confirm against the CSM deployment).
        output = io.StringIO()
        exit_code = safe_shell_exec.execute("{cmd} -n {node}".format(
            cmd=LSFUtils._CSM_NODE_QUERY,
            node=LSFUtils._csm_allocation_info["compute_nodes"][0]),
            stdout=output, stderr=output)
        if exit_code != 0:
            raise RuntimeError(
                "{cmd} failed with exit code {exit_code}".format(
                    cmd=LSFUtils._CSM_NODE_QUERY, exit_code=exit_code))
        node_output = yaml.safe_load(output.getvalue())

        # Usable cores = discovered cores minus the isolated cores reserved
        # on each socket (isolated_cores is per-socket in this formula).
        total_core_count = (int(node_output["Record_1"]["discovered_cores"])
                            - int(node_output["Record_1"]["discovered_sockets"])
                            * LSFUtils._csm_allocation_info["isolated_cores"])
        LSFUtils._csm_allocation_info["compute_node_cores"] = total_core_count
        LSFUtils._csm_allocation_info["compute_node_gpus"] = int(
            node_output["Record_1"]["discovered_gpus"])
        # Sorting LSF hostnames
        LSFUtils._csm_allocation_info["compute_nodes"].sort()
    return LSFUtils._csm_allocation_info
def do_test_run_with_controller_failure(self, controller, mode, run):
    """Drive _run() with a failing worker (a func returning via fn(0) with a
    forced failure, or the 'false' command) and verify that the failure is
    surfaced for the given controller/mode combination.
    """
    if run == 'func':
        command = None
        run_func = lambda: fn(0)
    elif run == 'cmd':
        command = 'false'
        run_func = None
    else:
        self.fail('unknown run argument {}'.format(run))

    # The expected error message differs between the MPI and gloo controllers.
    if controller == 'mpi':
        exception = 'mpirun failed with exit code 1'
    else:
        exception = 'Horovod detected that one or more processes exited with non-zero status'

    with self.horovod_args(mode, controller=controller, run_func=run_func,
                           command=command) as (hargs, exec):
        if controller == 'mpi' and run == 'cmd':
            # MPI command mode exec()s a shell; the mocked exec records the
            # call instead of replacing the process, so replay the recorded
            # command and check its exit code.
            self.assertIsNone(_run(hargs))
            exec.assert_called_once()
            args, kwargs = exec.call_args
            executable, args, env = args
            self.assertEqual('/bin/sh', executable)
            self.assertEqual(3, len(args))
            self.assertEqual('/bin/sh', args[0])
            self.assertEqual('-c', args[1])
            exit_code = safe_shell_exec.execute(args[2], env)
            self.assertEqual(1, exit_code)
        else:
            # All other combinations raise from _run() directly.
            with pytest.raises(RuntimeError, match=exception):
                _run(hargs)
def _exec_command(command, index, event):
    """Execute one worker command, optionally teeing its output to per-rank
    files under settings.output_filename, and return
    (exit_code, completion_timestamp)."""
    if settings.verbose:
        print(command)

    # By default the child inherits the parent's stdout/stderr.
    stdout = stderr = None
    stdout_file = stderr_file = None
    if settings.output_filename:
        rank_tag = 'rank.{rank}'.format(
            rank=_pad_rank(index, settings.num_proc))
        rank_dir = os.path.join(settings.output_filename, rank_tag)
        if not os.path.exists(rank_dir):
            os.mkdir(rank_dir)

        stdout_file = open(os.path.join(rank_dir, 'stdout'), 'w')
        stderr_file = open(os.path.join(rank_dir, 'stderr'), 'w')

        # Tee each stream to both the console and the per-rank file.
        stdout = MultiFile([sys.stdout, stdout_file])
        stderr = MultiFile([sys.stderr, stderr_file])

    try:
        exit_code = safe_shell_exec.execute(command, index=index, event=event,
                                            stdout=stdout, stderr=stderr)
        if exit_code != 0:
            print('Process {idx} exit with status code {ec}.'.format(
                idx=index, ec=exit_code))
    except Exception as e:
        # Launcher-side failures are reported as a generic non-zero exit.
        print('Exception happened during safe_shell_exec, exception '
              'message: {message}'.format(message=e))
        exit_code = 1
    finally:
        if stdout_file:
            stdout_file.close()
        if stderr_file:
            stderr_file.close()
    return exit_code, time.time()
def _is_open_mpi_installed():
    """Return True iff `mpirun --version` runs successfully and its output
    identifies Open MPI."""
    command = 'mpirun --version'
    output = six.StringIO()
    try:
        exit_code = safe_shell_exec.execute(command,
                                            stdout=output,
                                            stderr=output)
        output_msg = output.getvalue()
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        return False
    finally:
        output.close()

    # Guard clauses: report the specific failure and bail out.
    if exit_code != 0:
        print("Was not able to run %s:\n%s" % (command, output_msg),
              file=sys.stderr)
        return False
    if 'Open MPI' not in output_msg:
        print('Open MPI not found in output of mpirun --version.',
              file=sys.stderr)
        return False
    return True
def _get_mpi_implementation_flags():
    """Probe `mpirun --version` and return a fresh list of launcher flags for
    the detected MPI implementation, or None when no supported MPI is found
    or mpirun cannot be run."""
    command = 'mpirun --version'
    output = six.StringIO()
    try:
        exit_code = safe_shell_exec.execute(command,
                                            stdout=output,
                                            stderr=output)
        output_msg = output.getvalue()
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        return None
    finally:
        output.close()

    if exit_code != 0:
        print("Was not able to run %s:\n%s" % (command, output_msg),
              file=sys.stderr)
        return None

    # Map a marker substring in the version banner to its flag set; checked
    # in the same order as the original if/elif chain.
    for marker, flags in (('Open MPI', _OMPI_FLAGS),
                          ('IBM Spectrum MPI', _SMPI_FLAGS),
                          ('MPICH', _MPICH_FLAGS)):
        if marker in output_msg:
            return list(flags)

    print(
        'Open MPI/Spectrum MPI/MPICH not found in output of mpirun --version.',
        file=sys.stderr)
    return None
def do_test_run_with_controller_success(self, controller, mode, run):
    """Drive _run() with a succeeding worker (the `fn` function or the 'true'
    command) and verify success for the given controller/mode combination.
    """
    if run == 'func':
        command = None
        run_func = fn
    elif run == 'cmd':
        command = 'true'
        run_func = None
    else:
        self.fail('unknown run argument {}'.format(run))

    with self.horovod_args(mode, controller, run_func=run_func,
                           command=command) as (hargs, exec):
        if controller == 'mpi' and run == 'cmd':
            # MPI command mode exec()s a shell; the mocked exec records the
            # call instead of replacing the process, so replay the recorded
            # command and check that it exits cleanly.
            self.assertIsNone(_run(hargs))
            exec.assert_called_once()
            args, kwargs = exec.call_args
            executable, args, env = args
            self.assertEqual('/bin/sh', executable)
            self.assertEqual(3, len(args))
            self.assertEqual('/bin/sh', args[0])
            self.assertEqual('-c', args[1])
            exit_code = safe_shell_exec.execute(args[2], env)
            self.assertEqual(0, exit_code)
        else:
            actual = _run(hargs)
            # Func mode returns one (rank, np) tuple per rank; cmd mode
            # returns None.
            expected = list([(rank, hargs.np) for rank in range(hargs.np)]) if run == 'func' else None
            self.assertEqual(expected, actual)
def _exec_command(command, slot_info, events):
    """Run one worker command for the given slot, wrapping it in ssh when the
    slot's host is not local, and return (exit_code, completion_timestamp).
    """
    index = slot_info.rank
    host_name = slot_info.hostname

    host_address = network.resolve_host_address(host_name)
    local_addresses = network.get_local_host_addresses()
    if host_address not in local_addresses:
        # Remote host: run the command through non-interactive ssh, changing
        # into the driver's working directory first.  NOTE(review):
        # `ssh_port_arg` is a free name resolved from the enclosing scope —
        # it is not defined in this function.
        command = 'ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no ' \
                  '{host} {ssh_port_arg} ' \
                  '{local_command}' \
            .format(host=host_name,
                    ssh_port_arg=ssh_port_arg,
                    local_command=quote('cd {pwd} > /dev/null 2>&1 ; {local_command}'
                                        .format(pwd=os.getcwd(),
                                                local_command=command)))

    if settings.verbose:
        print(command)

    # Redirect output if requested
    stdout = stderr = None
    stdout_file = stderr_file = None
    if settings.output_filename:
        padded_rank = _pad_rank(index, settings.num_proc)
        output_dir_rank = os.path.join(
            settings.output_filename, 'rank.{rank}'.format(rank=padded_rank))
        if not os.path.exists(output_dir_rank):
            os.mkdir(output_dir_rank)

        stdout_file = open(os.path.join(output_dir_rank, 'stdout'), 'w')
        stderr_file = open(os.path.join(output_dir_rank, 'stderr'), 'w')

        # Tee each stream to both the console and the per-rank file.
        stdout = MultiFile([sys.stdout, stdout_file])
        stderr = MultiFile([sys.stderr, stderr_file])

    try:
        exit_code = safe_shell_exec.execute(command,
                                            index=index,
                                            stdout=stdout,
                                            stderr=stderr,
                                            events=events)
        if exit_code != 0:
            print('Process {idx} exit with status code {ec}.'.format(
                idx=index, ec=exit_code))
    except Exception as e:
        # Launcher-side failures are reported as a generic non-zero exit.
        print('Exception happened during safe_shell_exec, exception '
              'message: {message}'.format(message=e))
        exit_code = 1
    finally:
        if stdout_file:
            stdout_file.close()
        if stderr_file:
            stderr_file.close()
    return exit_code, time.time()
def _exec_command(_command, _index, event_):
    """Execute one worker command and return its exit code.

    Fix: the original returned a constant 0, discarding the real exit code
    (and the exception path left it unset), so every failure looked like
    success to the caller.  Now the actual exit code is propagated, and an
    exception in the launcher itself is reported as exit code 1 — matching
    the sibling _exec_command implementations in this codebase.
    """
    if settings.verbose:
        print(_command)
    try:
        exit_code = safe_shell_exec.execute(_command, index=_index,
                                            event=event_)
        if exit_code != 0:
            print('Process {idx} exit with status code {ec}.'.format(
                idx=_index, ec=exit_code))
    except Exception as e:
        print('Exception happened during safe_shell_exec, exception '
              'message: {message}'.format(message=e))
        exit_code = 1
    return exit_code
def get_num_threads():
    """Returns the number of hardware threads.

    Runs lscpu on the first compute host over ssh and extracts the thread
    count from its YAML-like output.
    """
    first_host = LSFUtils.get_compute_hosts()[0]
    lscpu_cmd = 'ssh -o StrictHostKeyChecking=no {host} {cmd}'.format(
        host=first_host,
        cmd=LSFUtils._LSCPU_CMD
    )
    output = io.StringIO()
    exit_code = safe_shell_exec.execute(lscpu_cmd,
                                        stdout=output,
                                        stderr=output)
    if exit_code != 0:
        raise RuntimeError("{cmd} failed with exit code {exit_code}".format(
            cmd=lscpu_cmd, exit_code=exit_code))
    lscpu_fields = yaml.safe_load(output.getvalue())
    return int(lscpu_fields[LSFUtils._THREAD_KEY])
def _exec_command(command):
    """Run the horovodrun task launch command, dumping its combined output
    and hard-exiting this process with the command's exit code on failure."""
    host_output = six.StringIO()
    try:
        exit_code = safe_shell_exec.execute(command,
                                            stdout=host_output,
                                            stderr=host_output)
        if exit_code == 0:
            return exit_code
        print('Launching horovodrun task function was not '
              'successful:\n{host_output}'.format(
                  host_output=host_output.getvalue()))
        # Propagate the failure to the parent immediately, skipping normal
        # interpreter teardown (the finally below still closes the buffer).
        os._exit(exit_code)
    finally:
        host_output.close()
def _do_test_safe_shell_exec(self, cmd, expected_exit_code, expected_stdout,
                             expected_stderr, event=None):
    """Run cmd through safe_shell_exec and assert its exit code and the
    exact contents of both captured streams."""
    captured_out = six.StringIO()
    captured_err = six.StringIO()
    res = safe_shell_exec.execute(cmd,
                                  stdout=captured_out,
                                  stderr=captured_err,
                                  event=event)
    self.assertEqual(expected_exit_code, res)
    self.assertEqual(expected_stdout, captured_out.getvalue())
    self.assertEqual(expected_stderr, captured_err.getvalue())
def execute(command):
    """
    Executes the command and returns stdout and stderr as a string,
    together with the exit code.

    :param command: command to execute
    :return: (output, exit code) or None on failure
    """
    buf = io.StringIO()
    try:
        exit_code = safe_shell_exec.execute(command, stdout=buf, stderr=buf)
        captured = buf.getvalue()
    except Exception:
        # Best effort: report the traceback and signal failure with None.
        print(traceback.format_exc(), file=sys.stderr)
        return None
    finally:
        buf.close()
    return captured, exit_code
def do_test_safe_shell_exec(self, cmd, expected_exit_code, expected_stdout,
                            expected_stderr, event=None):
    """Run cmd through safe_shell_exec and assert its exit code; stream
    contents are only checked when an expected value is given (None means
    "don't care")."""
    out_buf = io.StringIO()
    err_buf = io.StringIO()
    res = safe_shell_exec.execute(cmd,
                                  stdout=out_buf,
                                  stderr=err_buf,
                                  events=[event])
    self.assertEqual(expected_exit_code, res)
    if expected_stdout is not None:
        self.assertEqual(expected_stdout, out_buf.getvalue())
    if expected_stderr is not None:
        self.assertEqual(expected_stderr, err_buf.getvalue())
def exec_command(command):
    """Run a command (typically over ssh), retrying up to SSH_ATTEMPTS times.

    Returns (exit_code, output_msg), where output_msg holds the combined
    stdout/stderr of the most recent failed attempt (empty on first-try
    success)."""
    exit_code = 1
    output_msg = ''
    for _attempt in range(SSH_ATTEMPTS):
        attempt_output = six.StringIO()
        try:
            exit_code = safe_shell_exec.execute(command,
                                                stdout=attempt_output,
                                                stderr=attempt_output)
            if exit_code == 0:
                break
            output_msg = attempt_output.getvalue()
        finally:
            attempt_output.close()
    return exit_code, output_msg
def find_available_hosts_and_slots(self):
    """Run the discovery script and parse its output into a host→slots map.

    Each distinct output line is either ``hostname`` (assigned
    ``self._default_slots``) or ``hostname:slots``.

    Fix: blank lines (including the empty string produced when the script
    prints nothing) are now skipped; previously an empty output registered
    a bogus host ``''`` with the default slot count.

    Returns:
        dict mapping hostname to its number of available slots.

    Raises:
        RuntimeError: if the discovery script exits with non-zero status.
    """
    stdout = io.StringIO()
    exit_code = safe_shell_exec.execute(self._discovery_script, stdout=stdout)
    if exit_code != 0:
        raise RuntimeError(
            'Failed to execute discovery script: {}. Exit code: {}'.format(
                self._discovery_script, exit_code))

    host_slots = {}
    lines = set(stdout.getvalue().strip().split('\n'))
    for line in lines:
        if not line:
            # Empty script output (or stray blank lines) must not register
            # an empty hostname.
            continue
        host = line
        if ':' in line:
            host, slots = line.split(':')
            host_slots[host] = int(slots)
        else:
            host_slots[host] = self._default_slots
    return host_slots
def run():
    """Entry point for the horovodrun CLI: validate hosts, check ssh and
    network interfaces, then build and execute the mpirun command."""
    args = parse_args()

    if args.version:
        print(horovod.__version__)
        exit(0)

    if args.host:
        # Hosts are given as "name:slots" pairs; keep only the names.
        all_host_names = [x for x in
                          [y.split(':')[0] for y in args.host.split(',')]]
    else:
        all_host_names = []

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        # The cache key is derived from the run parameters so a change in
        # np/host/ssh_port invalidates cached check results.
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.host:
            params += str(args.host) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER,
                               CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '600'))

    tmout = timeout.Timeout(start_timeout)
    key = secret.make_secret_key()

    remote_host_names = []
    if args.host:
        if args.verbose:
            print("Filtering local host names.")
        remote_host_names = network.filter_local_addresses(all_host_names)

        if len(remote_host_names) > 0:
            if args.verbose:
                print("Checking ssh on all remote hosts.")
            # Check if we can ssh into all remote hosts successfully.
            _check_all_hosts_ssh_successful(remote_host_names, args.ssh_port,
                                            fn_cache=fn_cache)
            if args.verbose:
                print("SSH was successful into all the remote hosts.")

        hosts_arg = "-H {hosts}".format(hosts=args.host)
    else:
        # if user does not specify any hosts, mpirun by default uses local
        # host. There is no need to specify localhost.
        hosts_arg = ""

    if args.host and len(remote_host_names) > 0:
        if args.verbose:
            print("Testing interfaces on all the hosts.")

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one
        # interface otherwise, it will raise an exception.
        common_intfs = _driver_fn(key, all_host_names, local_host_names,
                                  tmout, args.ssh_port, args.verbose,
                                  fn_cache=fn_cache)

        tcp_intf_arg = "-mca btl_tcp_if_include {common_intfs}".format(
            common_intfs=','.join(common_intfs))
        nccl_socket_intf_arg = "-x NCCL_SOCKET_IFNAME={common_intfs}".format(
            common_intfs=','.join(common_intfs))

        if args.verbose:
            print("Interfaces on all the hosts were successfully checked.")
    else:
        # If all the given hosts are local, no need to specify the interfaces
        # because MPI does not use network for local execution.
        tcp_intf_arg = ""
        nccl_socket_intf_arg = ""

    # Pass all the env variables to the mpirun command.
    env = os.environ.copy()

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

    if not _is_open_mpi_installed():
        raise Exception(
            'horovodrun convenience script currently only supports '
            'Open MPI.\n\n'
            'Choose one of:\n'
            '1. Install Open MPI 4.0.0+ and re-install Horovod '
            '(use --no-cache-dir pip option).\n'
            '2. Run distributed '
            'training script using the standard way provided by your'
            ' MPI distribution (usually mpirun, srun, or jsrun).')

    if args.ssh_port:
        ssh_port_arg = "-mca plm_rsh_args \"-p {ssh_port}\"".format(
            ssh_port=args.ssh_port)
    else:
        ssh_port_arg = ""

    mpirun_command = (
        'mpirun --allow-run-as-root --tag-output '
        '-np {num_proc} {hosts_arg} '
        '-bind-to none -map-by slot '
        '-mca pml ob1 -mca btl ^openib '
        '{ssh_port_arg} '
        '{tcp_intf_arg} '
        '-x NCCL_DEBUG=INFO '
        '{nccl_socket_intf_arg} '
        '{env} {command}'  # expect a lot of environment variables
        .format(num_proc=args.np,
                hosts_arg=hosts_arg,
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                ssh_port_arg=ssh_port_arg,
                env=' '.join('-x %s' % key for key in env.keys()),
                command=' '.join(quote(par) for par in args.command))
    )

    if args.verbose:
        print(mpirun_command)

    # Execute the mpirun command.
    exit_code = safe_shell_exec.execute(mpirun_command, env)
    if exit_code != 0:
        raise Exception(
            'mpirun exited with code %d, see the error above.' % exit_code)
# Helper script: records its own PID to a log file, then runs the remaining
# argv as a child command through safe_shell_exec with the abort event
# mocked out (so the child never terminates early and no semaphores leak).
#
# Fix: the script calls os.rename and os.getpid but never imported os,
# which raised NameError at runtime; `import os` is now included.
import os
import sys
import time

from horovod.run.common.util import safe_shell_exec


class FakeEvent(object):
    """Stand-in for the abort event: wait() effectively never fires."""
    def wait(self):
        time.sleep(999)


def write(filename, value):
    """Write str(value) to filename via a temp file + atomic rename."""
    filename_tmp = filename + '.tmp'
    with open(filename_tmp, 'w') as f:
        f.write(str(value))
    # Atomic rename to prevent race conditions from reader
    os.rename(filename_tmp, filename)


if __name__ == '__main__':
    logfile = sys.argv[1]
    write(logfile, os.getpid())

    cmd = ' '.join([sys.executable] + sys.argv[2:])
    # Mock out the event to avoid leaking semaphores
    safe_shell_exec._create_event = lambda ctx: FakeEvent()
    safe_shell_exec.execute(cmd)
def _run_command(self, command, env, event):
    """Execute the command via safe_shell_exec (abortable through `event`)
    and record its exit code on the instance."""
    abort_events = [event]
    self._command_exit_code = safe_shell_exec.execute(command,
                                                      env=env,
                                                      events=abort_events)
def js_run(settings, common_intfs, env, command, stdout=None, stderr=None,
           run_func=safe_shell_exec.execute):
    """
    Runs Horovod with jsrun.

    Args:
        settings: Settings for running jsrun.
                  Note: settings.num_proc and settings.hosts must not be None.
        common_intfs: Interfaces to include by jsrun.
        env: Environment dictionary to use for running jsrun.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
        run_func: Run function to use. Must have arguments 'command' and 'env'.
                  Only used when settings.run_func_mode is True.
                  Defaults to safe_shell_exec.execute.
    """
    mpi_impl_flags, _ = _get_mpi_implementation_flags(settings.tcp_flag)
    if mpi_impl_flags is None:
        raise Exception(_MPI_NOT_FOUND_ERROR_MSG)

    if not is_jsrun_installed():
        raise Exception(
            'horovodrun convenience script does not find the jsrun command.\n\n'
            'Please, make sure you are running on a cluster with jsrun installed or '
            'use one of the other launchers.')

    if common_intfs and 'NCCL_SOCKET_IFNAME' not in env:
        # Respect a user-provided NCCL_SOCKET_IFNAME; only set it if absent.
        env['NCCL_SOCKET_IFNAME'] = ','.join(common_intfs)

    # MPI implementation flags are forwarded to jsrun via --smpiargs.
    smpiargs = ' '.join(mpi_impl_flags)
    if settings.extra_mpi_args:
        smpiargs += ' ' + settings.extra_mpi_args

    if settings.binding_args:
        binding_args = settings.binding_args
    else:
        # No explicit binding: generate an ERF rankfile describing placement.
        rf = generate_jsrun_rankfile(settings)
        if settings.verbose >= 2:
            safe_shell_exec.execute('cat {rf}'.format(rf=rf))
        binding_args = '--erf_input {rf}'.format(rf=rf)

    jsrun_command = (
        'jsrun {binding_args} '
        '{output_filename_arg} '
        '{smpiargs} '
        '{command}'.format(
            binding_args=binding_args,
            output_filename_arg='--stdio_stderr {file} --stdio_stdout {file}'
                                .format(file=settings.output_filename)
                                if settings.output_filename else '',
            smpiargs='--smpiargs {args}'.format(args=quote(smpiargs))
                     if smpiargs else '',
            command=' '.join(quote(par) for par in command)))

    if settings.verbose >= 2:
        print(jsrun_command)

    # Execute the jsrun command.
    if settings.run_func_mode:
        exit_code = run_func(command=jsrun_command, env=env,
                             stdout=stdout, stderr=stderr)
        if exit_code != 0:
            raise RuntimeError(
                "jsrun failed with exit code {exit_code}".format(
                    exit_code=exit_code))
    else:
        # Replace the current process with the jsrun shell invocation.
        os.execve('/bin/sh', ['/bin/sh', '-c', jsrun_command], env)
def mpi_run(settings, nics, env, command, stdout=None, stderr=None):
    """
    Runs mpi_run.

    Args:
        settings: Settings for running MPI.
                  Note: settings.num_proc and settings.hosts must not be None.
        nics: Interfaces to include by MPI.
        env: Environment dictionary to use for running command.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
    """
    if env is not None and not isinstance(env, dict):
        raise Exception(
            'env argument must be a dict, not {type}: {env}'.format(
                type=type(env), env=env))

    mpi_impl_flags, impl_binding_args = _get_mpi_implementation_flags(
        settings.tcp_flag, env=env)
    if mpi_impl_flags is None:
        raise Exception(_MPI_NOT_FOUND_ERROR_MSG)

    ssh_port_arg = '-mca plm_rsh_args \"-p {ssh_port}\"'.format(
        ssh_port=settings.ssh_port) if settings.ssh_port else ''

    # if user does not specify any hosts, mpirun by default uses local host.
    # There is no need to specify localhost.
    hosts_arg = '-H {hosts}'.format(hosts=settings.hosts)

    tcp_intf_arg = '-mca btl_tcp_if_include {nics}'.format(
        nics=','.join(nics)) if nics else ''
    nccl_socket_intf_arg = '-x NCCL_SOCKET_IFNAME={nics}'.format(
        nics=','.join(nics)) if nics else ''

    # On large cluster runs (e.g. Summit), we need extra settings to work
    # around OpenMPI issues
    if settings.num_hosts and settings.num_hosts >= _LARGE_CLUSTER_THRESHOLD:
        mpi_impl_flags.append('-mca plm_rsh_no_tree_spawn true')
        mpi_impl_flags.append('-mca plm_rsh_num_concurrent {}'.format(
            settings.num_hosts))

    # Explicit binding args override the implementation defaults.
    binding_args = settings.binding_args if settings.binding_args else ' '.join(
        impl_binding_args)

    # Pass all the env variables to the mpirun command.
    mpirun_command = (
        'mpirun --allow-run-as-root --tag-output '
        '-np {num_proc} {hosts_arg} '
        '{binding_args} '
        '{mpi_args} '
        '{ssh_port_arg} '
        '{tcp_intf_arg} '
        '{nccl_socket_intf_arg} '
        '{output_filename_arg} '
        '{env} {extra_mpi_args} {command}'  # expect a lot of environment variables
        .format(num_proc=settings.num_proc,
                hosts_arg=hosts_arg,
                binding_args=binding_args,
                mpi_args=' '.join(mpi_impl_flags),
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                ssh_port_arg=ssh_port_arg,
                output_filename_arg='--output-filename ' + settings.output_filename
                                    if settings.output_filename else '',
                env=' '.join('-x %s' % key for key in sorted(env.keys())
                             if env_util.is_exportable(key)),
                extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                command=' '.join(quote(par) for par in command))
    )

    if settings.verbose >= 2:
        print(mpirun_command)

    # we need the driver's PATH in env to run mpirun,
    # env for mpirun is different to env encoded in mpirun_command
    if 'PATH' not in env and 'PATH' in os.environ:
        env = copy.copy(env)  # copy env so we do not leak env modifications
        env['PATH'] = os.environ['PATH']

    # Execute the mpirun command.
    if settings.run_func_mode:
        exit_code = safe_shell_exec.execute(mpirun_command, env=env,
                                            stdout=stdout, stderr=stderr)
        if exit_code != 0:
            raise RuntimeError(
                "mpirun failed with exit code {exit_code}".format(
                    exit_code=exit_code))
    else:
        # Replace the current process with the mpirun shell invocation.
        os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)
def _exec(cmd):
    """Run cmd via safe_shell_exec and raise if it did not exit cleanly.

    Fix: the original tested ``exit_code is None or exit_code != 0``; the
    ``is None`` clause is redundant because ``None != 0`` is already True,
    so the simpler condition is exactly equivalent (a None exit code —
    e.g. a failure to launch — still raises, and the message still shows
    the raw value).

    Raises:
        RuntimeError: when the command's exit code is anything but 0.
    """
    exit_code = safe_shell_exec.execute(cmd)
    if exit_code != 0:
        raise RuntimeError(
            'executed command returned non-zero exit code: {}'.format(
                exit_code))
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        env=None, stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using
    the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to
                  `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start
                       running the code, in seconds.  If not set, falls back
                       to `HOROVOD_SPARK_START_TIMEOUT` environment variable
                       value.  If it is not set as well, defaults to 600
                       seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to
             `os.environ`.
        stdout: Horovod stdout is redirected to this stream.
                Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream.
                Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.
    """
    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
                'enough resources to run all Horovod processes. Each Horovod '
                'process runs in a Spark task. You may need to increase the '
                'start_timeout parameter to a larger value if your Spark resources '
                'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     key=secret.make_secret_key(),
                                     timeout=tmout)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args,
                                               kwargs, settings.key)
    # The Spark tasks run in a background thread so the driver can coordinate
    # registrations here while the job executes.
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings)
    try:
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [task_service.SparkTaskClient(
                            index,
                            driver.task_addresses_for_driver(index),
                            settings.key, settings.verbose)
                        for index in range(settings.num_proc)]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, settings.num_proc):
            common_intfs.intersection_update(
                driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception(
                'Unable to find a set of common task-to-task communication interfaces: %s'
                % [(index, driver.task_addresses_for_tasks(index))
                   for index in range(settings.num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        # Ranks are assigned host-by-host in host-hash order.
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses} {settings}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} {settings}'
            .format(num_proc=settings.num_proc,
                    hosts=','.join('%s:%d' % (host_hash,
                                              len(driver.task_host_hash_indices()[host_hash]))
                                   for host_hash in host_hashes),
                    common_intfs=','.join(common_intfs),
                    env=' '.join('-x %s' % key for key in env.keys()
                                 if env_util.is_exportable(key)),
                    python=sys.executable,
                    encoded_driver_addresses=codec.dumps_base64(driver.addresses()),
                    settings=codec.dumps_base64(settings)))
        if settings.verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.'
                            % exit_code)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()

    return [results[index] for index in ranks_to_indices]
def mpi_run(settings, common_intfs, env, command):
    """Build and run (or exec) the mpirun command for Horovod.

    Args:
        settings: Settings for running MPI.
                  Note: settings.num_proc and settings.hosts must not be None.
        common_intfs: Interfaces to include by MPI.
        env: Environment dictionary to use for running the command.
        command: Command and arguments to run as a list of string.

    Fix: on large clusters, `plm_rsh_num_concurrent` now scales with the
    number of hosts being ssh'd into (`settings.num_hosts`) rather than the
    total process count (`settings.num_proc`), matching the newer mpi_run
    implementation in this codebase — the rsh launcher opens one connection
    per host, not per process.
    """
    mpi_impl_flags = _get_mpi_implementation_flags()
    if mpi_impl_flags is None:
        raise Exception(
            'horovodrun convenience script does not find an installed MPI.\n\n'
            'Choose one of:\n'
            '1. Install Open MPI 4.0.0+ or IBM Spectrum MPI or MPICH and re-install Horovod '
            '(use --no-cache-dir pip option).\n'
            '2. Run distributed '
            'training script using the standard way provided by your'
            ' MPI distribution (usually mpirun, srun, or jsrun).\n'
            '3. Use built-in gloo option (horovodrun --gloo ...).')

    ssh_port_arg = '-mca plm_rsh_args \"-p {ssh_port}\"'.format(
        ssh_port=settings.ssh_port) if settings.ssh_port else ''

    # if user does not specify any hosts, mpirun by default uses local host.
    # There is no need to specify localhost.
    hosts_arg = '-H {hosts}'.format(hosts=settings.hosts)

    tcp_intf_arg = '-mca btl_tcp_if_include {common_intfs}'.format(
        common_intfs=','.join(common_intfs)) if common_intfs else ''
    nccl_socket_intf_arg = '-x NCCL_SOCKET_IFNAME={common_intfs}'.format(
        common_intfs=','.join(common_intfs)) if common_intfs else ''

    # On large cluster runs (e.g. Summit), we need extra settings to work
    # around OpenMPI issues
    if settings.num_hosts >= 64:
        mpi_impl_flags.append('-mca plm_rsh_no_tree_spawn true')
        # One concurrent rsh connection per host (was settings.num_proc).
        mpi_impl_flags.append('-mca plm_rsh_num_concurrent {}'.format(
            settings.num_hosts))

    # Pass all the env variables to the mpirun command.
    mpirun_command = (
        'mpirun --allow-run-as-root --tag-output '
        '-np {num_proc} {hosts_arg} '
        '-bind-to none -map-by slot '
        '{mpi_args} '
        '{ssh_port_arg} '
        '{tcp_intf_arg} '
        '{nccl_socket_intf_arg} '
        '{output_filename_arg} '
        '{env} {extra_mpi_args} {command}'  # expect a lot of environment variables
        .format(num_proc=settings.num_proc,
                hosts_arg=hosts_arg,
                mpi_args=' '.join(mpi_impl_flags),
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                ssh_port_arg=ssh_port_arg,
                output_filename_arg='--output-filename ' + settings.output_filename
                                    if settings.output_filename else '',
                env=' '.join('-x %s' % key for key in env.keys()
                             if env_util.is_exportable(key)),
                extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                command=' '.join(quote(par) for par in command))
    )

    if settings.verbose >= 2:
        print(mpirun_command)

    # Execute the mpirun command.
    if settings.run_func_mode:
        exit_code = safe_shell_exec.execute(mpirun_command, env=env)
        if exit_code != 0:
            raise RuntimeError(
                "mpirun failed with exit code {exit_code}".format(
                    exit_code=exit_code))
    else:
        # Replace the current process with the mpirun shell invocation.
        os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)