def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        use_mpi=None, use_gloo=None, extra_mpi_args=None,
        env=None, stdout=None, stderr=None, verbose=1, nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Not modified by this function.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Forwarded unchanged to `_launch_job`.
                 NOTE(review): presumably selects the MPI controller -- confirm against `_launch_job`.
        use_gloo: Forwarded unchanged to `_make_spark_thread` and `_launch_job`.
                  NOTE(review): presumably selects the Gloo controller -- confirm against `_launch_job`.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication. Converted to a set if it is not one already.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        Exception: If there is no active SparkContext.  Any error raised while registering
                   tasks or launching the job is re-raised after the Spark job group is cancelled.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
                'enough resources to run all Horovod processes. Each Horovod '
                'process runs in a Spark task. You may need to increase the '
                'start_timeout parameter to a larger value if your Spark resources '
                'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs,
                                               settings.key, settings.nics)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, use_gloo)
    try:
        # wait for all tasks to register and notify them
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(index,
                                         driver.task_addresses_for_driver(index),
                                         settings.key, settings.verbose)
            for index in range(settings.num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Read the host-hash -> task-indices mapping once so that the host
        # list and the rank order below are derived from the same snapshot.
        host_hash_indices = driver.task_host_hash_indices()

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = sorted(host_hash_indices.keys())
        while 0 not in host_hash_indices[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join('%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
                                  for host_hash in host_hashes)

        # Determine the ranks to indices: ranks are assigned host by host,
        # in the rotated host order computed above.
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += host_hash_indices[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except BaseException:
        # Deliberately catches everything (including KeyboardInterrupt) so the
        # Spark job is terminated before the exception is re-raised.
        spark_context.cancelJobGroup(spark_job_group)
        raise
    finally:
        # Shut the driver down even if joining the Spark thread is interrupted.
        try:
            spark_thread.join()
        finally:
            driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None, env=None,
        stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    The job is started by building an `mpirun` command line and executing it
    through `safe_shell_exec.execute`.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Not modified by this function.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
             The mapping is copied before the secret key is added, so the
             caller's object is never modified.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        Exception: If there is no active SparkContext, if the tasks share no common
                   task-to-task communication interface, or if `mpirun` exits with a
                   non-zero code.  Any error is re-raised after the Spark job group
                   is cancelled.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     key=secret.make_secret_key(),
                                     timeout=tmout)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs, settings.key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings)
    try:
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(index,
                                         driver.task_addresses_for_driver(index),
                                         settings.key, settings.verbose)
            for index in range(settings.num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, settings.num_proc):
            common_intfs.intersection_update(driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception('Unable to find a set of common task-to-task communication interfaces: %s'
                            % [(index, driver.task_addresses_for_tasks(index))
                               for index in range(settings.num_proc)])

        # Read the host-hash -> task-indices mapping once so that the host
        # list and the rank order below are derived from the same snapshot.
        host_hash_indices = driver.task_host_hash_indices()

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = sorted(host_hash_indices.keys())
        while 0 not in host_hash_indices[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        # Ranks are assigned host by host, in the rotated host order.
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += host_hash_indices[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Work on a private copy: the secret key added below must not leak
        # into a dictionary owned by the caller.
        env = os.environ.copy() if env is None else dict(env)

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

        hosts = ','.join('%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
                         for host_hash in host_hashes)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses} {settings}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} {settings}'
            .format(num_proc=settings.num_proc,
                    hosts=hosts,
                    common_intfs=','.join(common_intfs),
                    env=' '.join('-x %s' % key for key in env if env_util.is_exportable(key)),
                    python=sys.executable,
                    encoded_driver_addresses=codec.dumps_base64(driver.addresses()),
                    settings=codec.dumps_base64(settings)))
        if settings.verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.' % exit_code)
    except BaseException:
        # Deliberately catches everything (including KeyboardInterrupt) so the
        # Spark job is terminated before the exception is re-raised.
        spark_context.cancelJobGroup(spark_job_group)
        raise
    finally:
        # Shut the driver down even if joining the Spark thread is interrupted.
        try:
            spark_thread.join()
        finally:
            driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
def run_elastic(fn, args=(), kwargs={}, num_proc=None, min_num_proc=None, max_num_proc=None,
                start_timeout=None, elastic_timeout=None, reset_limit=None, env=None,
                stdout=None, stderr=None, verbose=1, nics=None,
                prefix_output_with_timestamp=False,
                # min_np is deprecated, use min_num_proc instead
                min_np=None,
                # max_np is deprecated, use max_num_proc instead
                max_np=None):
    """
    Runs Elastic Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Not modified by this function.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        min_num_proc: Minimum number of processes running for training to continue.  If number of
                      available processes dips below this threshold, then training will wait for
                      more instances to become available.  Defaults to `num_proc`.
        max_num_proc: Maximum number of training processes, beyond which no additional processes
                      will be created.  Defaults to `num_proc`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        elastic_timeout: Timeout for elastic initialisation after re-scaling the cluster.
                         Passed unchanged to `ElasticSettings`; this function does not resolve a default.
                         NOTE(review): a fallback to the `HOROVOD_ELASTIC_TIMEOUT` environment variable
                         (600 seconds if unset) is presumably applied downstream -- confirm.
        reset_limit: Maximum number of resets after which the job is terminated.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream.
        stderr: Horovod stderr is redirected to this stream.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication. Converted to a set if it is not one already.
        prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver
        min_np: Deprecated alias of `min_num_proc`; takes precedence over it when given.
        max_np: Deprecated alias of `max_num_proc`; takes precedence over it when given.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        ValueError: If Gloo support has not been built.
        Exception: If there is no active SparkContext.  Any error raised while registering
                   tasks or running the job is re-raised after the Spark job group is cancelled.
    """

    # stacklevel=2 attributes the deprecation to the caller's line, which is
    # where the deprecated argument actually needs to be replaced.
    if min_np is not None:
        min_num_proc = min_np
        warnings.warn('min_np is deprecated, use min_num_proc instead',
                      DeprecationWarning, stacklevel=2)
    if max_np is not None:
        max_num_proc = max_np
        warnings.warn('max_np is deprecated, use max_num_proc instead',
                      DeprecationWarning, stacklevel=2)

    if not gloo_built(verbose=(verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built. Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    if num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.initialExecutors
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            logging.info('Running %d processes (inferred from spark.default.parallelism)...', num_proc)
    else:
        if verbose >= 1:
            logging.info('Running %d processes...', num_proc)

    if min_num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.minExecutors
        min_num_proc = num_proc
    if max_num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.maxExecutors
        max_num_proc = num_proc

    # start Spark driver service and launch settings.num_proc Spark tasks
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(num_proc, max_num_proc,
                                               fn, args, kwargs,
                                               key, nics)

    discovery = host_discovery.SparkDriverHostDiscovery(driver)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
                'enough resources to run all Horovod processes. Each Horovod '
                'process runs in a Spark task. You may need to increase the '
                'start_timeout parameter to a larger value if your Spark resources '
                'are allocated on-demand.')
    settings = hvd_elastic_settings.ElasticSettings(
        discovery=discovery,
        min_num_proc=min_num_proc,
        max_num_proc=max_num_proc,
        elastic_timeout=elastic_timeout,
        reset_limit=reset_limit,
        num_proc=num_proc,
        verbose=verbose,
        key=key,
        start_timeout=tmout,
        nics=nics,
        run_func_mode=True,
        prefix_output_with_timestamp=prefix_output_with_timestamp)

    result_queue = queue.Queue(1)

    # launch settings.num_proc / settings.max_num_proc Spark tasks
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings,
                                      use_gloo=True, is_elastic=True)
    try:
        # Register task addresses of initial num_proc tasks
        _register_task_addresses(driver, settings)

        # Run the job
        gloo_run_elastic(settings, driver, env, stdout, stderr)
    except BaseException:
        # Deliberately catches everything (including KeyboardInterrupt) so the
        # Spark job is terminated before the exception is re-raised.
        spark_context.cancelJobGroup(spark_job_group)
        raise
    finally:
        # Shut the driver down even if joining the Spark thread is interrupted.
        try:
            spark_thread.join()
        finally:
            driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]