def test_shutdown_during_request_basic(self):
    """Shutdown issued while five sleep requests are in flight.

    Five non-daemon threads each send one sleep request. Half-way through the
    sleep the service is shut down. The elapsed time is then expected to be at
    least one sleep (requests completed) but below one sleep plus a second
    (requests ran concurrently), and every request thread must have finished.
    """
    sleep = 2.0
    key = secret.make_secret_key()
    service = TestSleepService(key, duration=sleep)
    try:
        client = TestSleepClient(service.addresses(), key, attempts=1)
        start = time.time()
        threads = [
            in_thread(client.sleep, name='request {}'.format(n), daemon=False)
            for n in range(1, 6)
        ]
        # shut down while the requests are still sleeping on the service side
        time.sleep(sleep / 2.0)
    finally:
        service.shutdown()

    duration = time.time() - start
    print('shutdown completed in {} seconds'.format(duration))
    self.assertGreaterEqual(duration, sleep, 'sleep requests should have been completed')
    self.assertLess(duration, sleep + 1.0, 'sleep requests should have been concurrent')

    for worker in threads:
        worker.join(0.1)
        self.assertFalse(worker.is_alive(), 'thread should have terminated by now')
def test_invalid_dispatcher_ids(self):
    """Dispatcher-indexed ComputeClient calls reject ids outside [0, 2) with IndexError."""
    key = secret.make_secret_key()
    service = ComputeService(2, 4, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        # (client method, positional arguments that follow the dispatcher id)
        calls = [
            (client.register_dispatcher, ('grpc://localhost:10000',)),
            (client.wait_for_dispatcher_registration, (0.1,)),
            (client.register_worker_for_dispatcher, (0,)),
            (client.wait_for_dispatcher_worker_registration, (0.1,)),
        ]
        for method, extra_args in calls:
            # -1 is below the range, 2 is one past the last of the 2 dispatchers
            for bad_id in (-1, 2):
                with self.assertRaises(IndexError):
                    method(bad_id, *extra_args)
    finally:
        service.shutdown()
def test_shutdown_during_request_basic_task(self):
    """Shutting down a BasicTaskService while a client waits on a running command.

    A background thread blocks on `wait_for_command_exit_code` while a
    `sleep 2` command runs. The service is shut down 0.5s later. The test
    expects shutdown to take at least 2s overall, new requests to fail with a
    connection error, and the waiting thread to finish.
    """
    result_queue = queue.Queue(1)

    def wait_for_exit_code(client, results):
        # named `results` so the stdlib `queue` module is not shadowed
        results.put(client.wait_for_command_exit_code())

    key = secret.make_secret_key()
    service_name = 'test-service'
    service = BasicTaskService(service_name, key, nics=None, verbose=2)
    client = BasicTaskClient(service_name, service.addresses(), key, verbose=2, attempts=1)
    thread = threading.Thread(target=wait_for_exit_code, args=(client, result_queue))

    start = time.time()
    thread.start()  # wait for command exit code
    client.run_command('sleep 2', {})  # execute command
    time.sleep(0.5)  # give the thread some time to connect before shutdown
    service.shutdown()  # shutdown should wait on request to finish
    duration = time.time() - start
    self.assertGreaterEqual(duration, 2)

    # we cannot call after shutdown
    # The alternation is grouped so ^ and $ anchor BOTH messages; the previous
    # `^(A)|(B)$` anchored only the start of A and only the end of B.
    with pytest.raises(Exception,
                       match=r'^(\[[Ee]rrno 104\] Connection reset by peer'
                             r'|\[[Ee]rrno 111\] Connection refused)$'):
        client.command_result()

    # but still our long running request succeeded
    thread.join(1.0)
    self.assertFalse(thread.is_alive())
def init(self, rendezvous_addr=None, rendezvous_port=None, nic=None, hostname=None, local_rank=None):
    """Start the worker notification service once and publish its addresses.

    Does nothing if a service was already started on this object, or if no
    rendezvous address is given either as an argument or in the environment
    variable named by HOROVOD_GLOO_RENDEZVOUS_ADDR. Otherwise, arguments left
    as None/empty are read from the corresponding environment variables. A
    WorkerNotificationService is then started with a fresh secret key, and
    `(addresses, key)` is stored in the rendezvous KV store under
    PUT_WORKER_ADDRESSES, keyed by `self._create_id(hostname, local_rank)`.

    NOTE(review): `int(...)` on a missing port / local-rank variable raises
    TypeError (int(None)).
    """
    with self._lock:
        if self._service:
            return

        addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
        if not addr:
            return

        if rendezvous_port is None:
            rendezvous_port = int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
        nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
        hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
        if local_rank is None:
            local_rank = int(os.environ.get(HOROVOD_LOCAL_RANK))

        key = secret.make_secret_key()
        self._service = WorkerNotificationService(key, nic, self)
        payload = (self._service.addresses(), key)
        put_data_into_kvstore(addr, rendezvous_port, PUT_WORKER_ADDRESSES,
                              self._create_id(hostname, local_rank), payload)
def _run_elastic(args):
    """Launch an elastic (fault-tolerant) Gloo job from parsed horovodrun arguments.

    Host discovery comes from `--host-discovery-script` if given, else from a
    fixed `--hosts` list, which must name at least 2 hosts.

    Returns:
        Whatever `gloo_run_elastic` returns.

    Raises:
        ValueError: if no host source is given, fewer than 2 fixed hosts are
            given, or Gloo support has not been built.
    """
    if args.host_discovery_script:
        discoverer = discovery.HostDiscoveryScript(args.host_discovery_script, args.slots)
    elif args.hosts:
        _, host_slots = hosts.parse_hosts_and_slots(args.hosts)
        if len(host_slots) < 2:
            raise ValueError(
                'Cannot run in fault tolerance mode with fewer than 2 hosts.')
        discoverer = discovery.FixedHosts(host_slots)
    else:
        raise ValueError(
            'One of --host-discovery-script, --hosts, or --hostnames must be provided'
        )

    # horovodrun has to finish all the checks before this timeout runs out;
    # without --start-timeout, fall back to HOROVOD_START_TIMEOUT (default 30).
    start_timeout = args.start_timeout or int(os.getenv('HOROVOD_START_TIMEOUT', '30'))
    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                                    'check connectivity between servers. You '
                                    'may need to increase the --start-timeout '
                                    'parameter if you have too many servers.')

    settings = elastic_settings.ElasticSettings(
        discovery=discoverer,
        min_num_proc=args.min_num_proc or args.num_proc,
        max_num_proc=args.max_num_proc,
        elastic_timeout=args.elastic_timeout,
        reset_limit=args.reset_limit,
        cooldown_range=args.cooldown_range,
        num_proc=args.num_proc,
        verbose=2 if args.verbose else 0,
        ssh_port=args.ssh_port,
        ssh_identity_file=args.ssh_identity_file,
        extra_mpi_args=args.mpi_args,
        key=secret.make_secret_key(),
        start_timeout=tmout,
        output_filename=args.output_filename,
        run_func_mode=args.run_func is not None,
        nics=args.nics,
        prefix_output_with_timestamp=args.prefix_output_with_timestamp)

    if not gloo_built(verbose=(settings.verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)
    executable = args.executable or sys.executable
    command = args.run_func if args.run_func else args.command
    return gloo_run_elastic(settings, env, command, executable)
def __post_init__(self):
    """Post-init hook: write the SSH identity file if needed and ensure a secret key.

    When `ssh_str` is set and `ssh_identity_file` does not exist yet, the
    file is created atomically with owner-only (0o600) permissions and
    `ssh_str` is written to it. An existing file is never overwritten. If
    `key` is None, a fresh secret key is generated.
    """
    if self.ssh_str and not os.path.exists(self.ssh_identity_file):
        # Create the file with 0o600 from the start (O_EXCL = fail if it
        # exists) so the private key is never readable by other users, not
        # even between file creation and chmod.
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
        try:
            fd = os.open(self.ssh_identity_file, flags, 0o600)
        except FileExistsError:
            pass  # created concurrently after the exists() check; leave it alone
        else:
            with os.fdopen(fd, "w") as f:
                # os.open's mode is filtered by the umask; pin it to exactly 0o600
                os.chmod(self.ssh_identity_file, 0o600)
                f.write(self.ssh_str)
    if self.key is None:
        self.key = secret.make_secret_key()
def test_mpi_run_full(self):
    """mpi_run with every setting populated builds the expected mpirun command and env.

    Both `_get_mpi_implementation_flags` and `safe_shell_exec.execute` are
    patched, so nothing is actually launched. The test inspects the arguments
    the patched `execute` was called with.
    """
    if not mpi_available():
        self.skipTest("MPI is not available")

    cmd = ['cmd', 'arg1', 'arg2']
    nics = ['eth0', 'eth1']
    env = {'env1': 'val1', 'env2': 'val2'}
    stdout = '<stdout>'
    stderr = '<stderr>'
    tmout = timeout.Timeout(5, message='Timed out waiting for something.')
    # Placeholder values make it easy to spot where each setting ends up in the command line.
    settings = hvd_settings.Settings(
        verbose=0,
        ssh_port=1022,
        extra_mpi_args='>mpi-extra args go here<',
        binding_args='>binding args go here<',
        key=secret.make_secret_key(),
        start_timeout=tmout,
        num_proc=1,
        hosts='localhost:1',
        output_filename='>output filename goes here<',
        run_func_mode=True
    )

    # Stand-in for _get_mpi_implementation_flags: fixed flag list plus an empty second element.
    def mpi_impl_flags(tcp, env=None):
        return ["--mock-mpi-impl-flags"], []

    with mock.patch("horovod.runner.mpi_run._get_mpi_implementation_flags", side_effect=mpi_impl_flags) as impl:
        with mock.patch("horovod.runner.mpi_run.safe_shell_exec.execute", return_value=0) as execute:
            mpi_run(settings, nics, env, cmd, stdout=stdout, stderr=stderr)

            # assert call on _get_mpi_implementation_flags
            impl.assert_called_once_with(None, env=env)

            # call the mocked _get_mpi_implementation_flags method ourselves
            mpi_flags, _ = horovod.runner.mpi_run._get_mpi_implementation_flags(False)
            self.assertIsNotNone(mpi_flags)
            expected_command = ('mpirun '
                                '--allow-run-as-root --tag-output '
                                '-np 1 -H {hosts} '
                                '>binding args go here< '
                                '{mpi_flags} '
                                '-mca plm_rsh_args "-p 1022" '
                                '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
                                '--output-filename >output filename goes here< '
                                '-x env1 -x env2 '
                                '>mpi-extra args go here< '
                                'cmd arg1 arg2').format(hosts=settings.hosts,
                                                        mpi_flags=' '.join(mpi_flags))

            # remove PYTHONPATH from execute's env
            # we cannot know the exact value of that env variable
            # we test right handling of PYTHONPATH in test_mpi_run_*pythonpath* below
            # (this mutates the env dict recorded by the mock, so the assertion below ignores it)
            self.assertIn('env', execute.call_args.kwargs)
            if 'PYTHONPATH' in execute.call_args.kwargs['env']:
                execute.call_args.kwargs['env'].pop('PYTHONPATH')

            expected_env = {'env1': 'val1', 'env2': 'val2', 'PATH': os.environ.get('PATH')}
            execute.assert_called_once_with(expected_command, env=expected_env, stdout=stdout, stderr=stderr)
def test_exit_code(self):
    """test non-zero exit code

    Runs `false` through a BasicTaskService and expects exit code 1. The
    service is always shut down, even when the assertion fails.
    """
    key = secret.make_secret_key()
    service_name = 'test-service'
    service = BasicTaskService(service_name, key, nics=None, verbose=2)
    try:
        client = BasicTaskClient(service_name, service.addresses(), key, verbose=2, attempts=1)
        client.run_command('false', {})
        res = client.wait_for_command_exit_code()
        self.assertEqual(1, res)
    finally:
        # previously the service was never shut down and leaked past the test
        service.shutdown()
def spark_task_service(index, key=None, nics=None, match_intf=False,
                       minimum_command_lifetime_s=0, verbose=2):
    """Yield a SparkTaskService plus a client for it, shutting the service down afterwards.

    Args:
        index: task index passed to both service and client.
        key: secret key; a fresh one is generated when falsy.
        nics: network interfaces forwarded to the service.
        match_intf: forwarded to SparkTaskClient.
        minimum_command_lifetime_s: forwarded to SparkTaskService.
        verbose: verbosity forwarded to service and client.

    Yields:
        tuple: (task service, task client, key).
    """
    key = key or secret.make_secret_key()
    task = SparkTaskService(index, key, nics, minimum_command_lifetime_s, verbose)
    try:
        # created inside the try so the service is shut down even if the
        # client constructor raises
        client = SparkTaskClient(index, task.addresses(), key, verbose, match_intf)
        yield task, client, key
    finally:
        task.shutdown()
def create_settings(min_np: int = 1,
                    max_np: int = None,
                    reset_limit: int = None,
                    elastic_timeout: int = 600,
                    timeout_s: int = 30,
                    ssh_identity_file: str = None,
                    nics: set = None,
                    **kwargs):
    """Returns a Settings object for ElasticRayExecutor.

    Note that the `discovery` property will be set at runtime.

    Args:
        min_np (int): Minimum number of processes running for training to
            continue. If number of available processes dips below this
            threshold, then training will wait for more instances to become
            available. Also used as the initial number of processes.
        max_np (int): Maximum number of training processes, beyond which no
            additional processes will be created. If not specified, then will
            be unbounded.
        reset_limit (int): Maximum number of times that the training job can
            scale up or down the number of workers after which the job is
            terminated.
        elastic_timeout (int): Timeout for elastic initialisation after
            re-scaling the cluster. The default value is 600 seconds.
            Alternatively, the environment variable HOROVOD_ELASTIC_TIMEOUT
            can also be used.
        timeout_s (int): Horovod performs all the checks and starts the
            processes before the specified timeout. The default value is 30
            seconds.
        ssh_identity_file (str): File on the driver from which the identity
            (private key) is read. Defaults to ``~/ray_bootstrap_key.pem``.
        nics (set): Network interfaces that can be used for communication.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ElasticSettings.

    Returns:
        ElasticSettings: settings with ``discovery=None`` and a freshly
        generated secret key (or None when ``secret`` is unavailable).
    """
    start_timeout = timeout.Timeout(
        timeout_s,
        message="Timed out waiting for {activity}. Please "
        "check connectivity between servers. You "
        "may need to increase the --start-timeout "
        "parameter if you have too many servers.")
    ssh_identity_file = ssh_identity_file or os.path.expanduser(
        "~/ray_bootstrap_key.pem")
    settings = ElasticSettings(
        discovery=None,
        min_np=min_np,
        max_np=max_np,
        elastic_timeout=elastic_timeout,
        reset_limit=reset_limit,
        num_proc=min_np,
        ssh_identity_file=ssh_identity_file,
        nics=nics,
        start_timeout=start_timeout,
        key=secret.make_secret_key() if secret else None,
        **kwargs)
    return settings
def spark_driver_service(num_proc, initial_np=None, fn=fn, args=(), kwargs={},
                         key=None, nics=None, verbose=2):
    """Yield a SparkDriverService plus a client for it, shutting the driver down afterwards.

    Args:
        num_proc: number of processes; also the initial count when
            `initial_np` is falsy.
        initial_np: initial number of processes.
        fn, args, kwargs: function and its arguments forwarded to the driver.
            NOTE(review): `kwargs` defaults to a shared dict; this function
            only forwards it, so callers should not rely on mutating it.
        key: secret key; a fresh one is generated when falsy.
        nics: network interfaces forwarded to the driver.
        verbose: verbosity forwarded to the client.

    Yields:
        tuple: (driver service, driver client, key).
    """
    initial_np = initial_np or num_proc
    key = key or secret.make_secret_key()
    driver = SparkDriverService(initial_np, num_proc, fn, args, kwargs, key, nics)
    try:
        # created inside the try so the driver is shut down even if the
        # client constructor raises
        client = SparkDriverClient(driver.addresses(), key, verbose)
        yield driver, client, key
    finally:
        driver.shutdown()
def __post_init__(self):
    """Post-init hook: check the Ray extra, write the SSH identity file, ensure a secret key.

    When `ssh_str` is set and `ssh_identity_file` does not exist yet, the
    file is created atomically with owner-only (0o600) permissions and
    `ssh_str` is written to it. An existing file is never overwritten. If
    `key` is None, a fresh secret key is generated.

    Raises:
        ValueError: if ``horovod[ray]`` is not installed (``Coordinator`` is None).
    """
    if Coordinator is None:
        raise ValueError(
            "`horovod[ray]` is not installed. "
            "Please install 'horovod[ray]' to use this backend.")

    if self.ssh_str and not os.path.exists(self.ssh_identity_file):
        # Create the file with 0o600 from the start (O_EXCL = fail if it
        # exists) so the private key is never readable by other users, not
        # even between file creation and chmod.
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
        try:
            fd = os.open(self.ssh_identity_file, flags, 0o600)
        except FileExistsError:
            pass  # created concurrently after the exists() check; leave it alone
        else:
            with os.fdopen(fd, "w") as f:
                # os.open's mode is filtered by the umask; pin it to exactly 0o600
                os.chmod(self.ssh_identity_file, 0o600)
                f.write(self.ssh_str)
    if self.key is None:
        self.key = secret.make_secret_key()
def test_register_dispatcher_duplicate(self):
    """Re-registering dispatcher 0 under a different address raises ValueError."""
    key = secret.make_secret_key()
    service = ComputeService(2, 1, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        first_address = 'grpc://localhost:10000'
        conflicting_address = 'grpc://localhost:10001'
        client.register_dispatcher(0, first_address)

        expected_message = ('Dispatcher with id 0 has already been registered under '
                            'different address grpc://localhost:10000: grpc://localhost:10001')
        with self.assertRaisesRegex(ValueError, expected_message):
            client.register_dispatcher(0, conflicting_address)
    finally:
        service.shutdown()
def test_register_dispatcher_worker_timeout(self):
    """Waiting for workers that never register raises TimeoutException with a descriptive message."""
    key = secret.make_secret_key()
    service = ComputeService(1, 1, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        message = ('Timed out waiting for workers for dispatcher 0 to register. '
                   'Try to find out what takes the workers so long '
                   'to register or increase timeout. Timeout after 0.1 seconds.')
        with self.assertRaisesRegex(TimeoutException, expected_regex=message):
            client.wait_for_dispatcher_worker_registration(0, timeout=0.1)
    finally:
        service.shutdown()
def test_concurrent_requests_basic(self):
    """Several sleep requests issued in parallel complete in about one sleep duration.

    Total elapsed time must be at least one sleep (requests completed) but
    below one sleep plus a second (requests ran concurrently). The service is
    always shut down afterwards.
    """
    sleep = 2.0
    # More than one request is required; with a single request the
    # "should have been concurrent" assertion below proves nothing.
    requests = 5
    key = secret.make_secret_key()
    service = TestSleepService(key, duration=sleep)
    try:
        client = TestSleepClient(service.addresses(), key, attempts=1)
        start = time.time()
        threads = [in_thread(client.sleep, daemon=False) for _ in range(requests)]
        for thread in threads:
            thread.join(sleep + 1.0)
            self.assertFalse(thread.is_alive(), 'thread should have terminated by now')
        duration = time.time() - start
        print('concurrent requests completed in {} seconds'.format(duration))
    finally:
        # previously the service was never shut down and leaked past the test
        service.shutdown()

    self.assertGreaterEqual(duration, sleep, 'sleep requests should have been completed')
    self.assertLess(duration, sleep + 1.0, 'sleep requests should have been concurrent')
def test_stream(self):
    """A streaming sleep request writes 'slept <duration>' into the given stream."""
    sleep = 2.0
    key = secret.make_secret_key()
    service = TestStreamService(key, duration=sleep)
    try:
        client = TestStreamClient(service.addresses(), key, attempts=1)
        began = time.time()
        output = io.StringIO()
        client.sleep(output)
        elapsed = time.time() - began

        self.assertEqual(f'slept {sleep}', output.getvalue())
        self.assertGreaterEqual(elapsed, 2)
    finally:
        service.shutdown()
def test_register_dispatcher_worker_replay(self):
    """Registering the same worker twice counts once; a second distinct worker completes registration."""
    key = secret.make_secret_key()
    service = ComputeService(1, 2, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        client.register_dispatcher(0, 'grpc://localhost:10000')
        client.register_worker_for_dispatcher(0, 0)

        # register the same worker again should not complete the registration
        client.register_worker_for_dispatcher(0, 0)
        with self.assertRaises(TimeoutException):
            client.wait_for_dispatcher_worker_registration(0, timeout=2)

        # registering the second worker (id 1) completes registration of dispatcher 0
        client.register_worker_for_dispatcher(0, 1)
        client.wait_for_dispatcher_worker_registration(0, timeout=2)
    finally:
        service.shutdown()
class MiniSettings:
    """Minimal settings necessary for Ray to work.

    Can be replaced with a proper Horovod Settings object.
    """
    # Network interfaces usable for communication.
    nics: set = None
    # Verbosity level.
    verbose: int = 1
    # NOTE(review): this default is evaluated once, when the class body runs,
    # so every instance that does not override `key` shares the same secret
    # key. None when `secret` is unavailable.
    key: str = secret.make_secret_key() if secret else None
    # SSH port; None leaves the choice to the consumer of these settings.
    ssh_port: int = None
    # Path of the SSH private key file.
    ssh_identity_file: str = None
    # Seconds used to build `start_timeout`.
    timeout_s: int = 300

    @property
    def start_timeout(self):
        """Return a new Timeout of `timeout_s` seconds with a connectivity hint message.

        NOTE(review): a fresh Timeout object is built on every property
        access. If Timeout measures from its construction time (presumably),
        callers should read this once and reuse the object.
        """
        return timeout.Timeout(
            self.timeout_s,
            message="Timed out waiting for {activity}. Please "
            "check connectivity between servers. You "
            "may need to increase the --start-timeout "
            "parameter if you have too many servers.")
def main(dispatchers: int, dispatcher_side: str, configfile: str, timeout: int):
    """Run a tf.data compute service across all Horovod ranks.

    Rank 0 starts the ComputeService and writes its TfDataServiceConfig to
    `configfile`. The config is broadcast to every rank, and each rank then
    runs `compute_worker_fn`. The ComputeService, if started, is shut down on
    exit.

    Raises:
        ValueError: if the number of processes is not a multiple of `dispatchers`.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    workers_per_dispatcher, remainder = divmod(size, dispatchers)
    if remainder:
        raise ValueError(
            f'Number of processes ({size}) must be a multiple of number of dispatchers ({dispatchers}).'
        )

    compute = None  # only rank 0 owns a ComputeService
    try:
        compute_config = None
        if rank == 0:
            key = secret.make_secret_key()
            compute = ComputeService(dispatchers, workers_per_dispatcher, key=key)
            compute_config = TfDataServiceConfig(
                dispatchers=dispatchers,
                workers_per_dispatcher=workers_per_dispatcher,
                dispatcher_side=dispatcher_side,
                addresses=compute.addresses(),
                key=key,
                timeout=timeout)
            compute_config.write(configfile)

        # broadcast this config to all ranks via CPU ops
        with tf.device('/cpu:0'):
            compute_config = hvd.broadcast_object(compute_config, name='TfDataServiceConfig')

        # start all compute workers
        compute_worker_fn(compute_config)
    finally:
        if compute is not None:
            compute.shutdown()
def test_good_path(self):
    """End-to-end: register all dispatchers, then all their workers, then shut down via the client.

    Runs once per (dispatchers, workers per dispatcher) configuration, each in
    its own subTest with its own ComputeService.
    """
    configurations = [(1, 1), (1, 2), (1, 4), (2, 1), (2, 2), (2, 4), (32, 16), (1, 512)]
    for dispatchers_num, workers_per_dispatcher in configurations:
        with self.subTest(dispatchers=dispatchers_num, workers_per_dispatcher=workers_per_dispatcher):
            key = secret.make_secret_key()
            service = ComputeService(dispatchers_num, workers_per_dispatcher, key, nics=None)
            try:
                client = ComputeClient(service.addresses(), key, verbose=2)

                # create thread waiting for shutdown
                shutdown = Queue()
                shutdown_thread = self.wait_for_shutdown(client, shutdown)

                # --- dispatcher registration ---
                # start threads that wait for dispatchers
                reported = Queue()
                waiters = [self.wait_for_dispatcher(client, dispatcher_id, reported)
                           for dispatcher_id in range(dispatchers_num)]

                # register dispatchers
                for dispatcher_id in range(dispatchers_num):
                    client.register_dispatcher(dispatcher_id, f'grpc://localhost:{10000+dispatcher_id}')

                # check threads terminate
                for waiter in waiters:
                    waiter.join(10)
                    self.assertFalse(waiter.is_alive(),
                                     msg="threads waiting for dispatchers did not terminate")

                # check reported dispatcher addresses
                expected_addresses = [(dispatcher_id, f'grpc://localhost:{10000+dispatcher_id}')
                                      for dispatcher_id in range(dispatchers_num)]
                self.assertEqual(expected_addresses, sorted(self.get_all(reported)))

                # --- worker registration ---
                # start threads to wait for dispatcher worker registration
                reported = Queue()
                waiters = [self.wait_for_dispatcher_workers(client, dispatcher_id, reported)
                           for dispatcher_id in range(dispatchers_num)]

                # register dispatcher workers; consecutive worker ids share a dispatcher
                for worker_id in range(dispatchers_num * workers_per_dispatcher):
                    client.register_worker_for_dispatcher(
                        dispatcher_id=worker_id // workers_per_dispatcher,
                        worker_id=worker_id)

                # check threads terminate
                for waiter in waiters:
                    waiter.join(10)
                    self.assertFalse(waiter.is_alive(),
                                     msg="threads waiting for dispatchers' workers did not terminate")

                # check reported dispatcher success
                self.assertEqual(sorted(range(dispatchers_num)), sorted(self.get_all(reported)))

                # shutdown and wait for shutdown
                self.assertTrue(shutdown_thread.is_alive(),
                                msg="thread waiting for shutdown, terminated early")
                client.shutdown()
                shutdown_thread.join(10)
                self.assertFalse(shutdown_thread.is_alive(),
                                 msg="thread waiting for shutdown did not terminate")
                self.assertEqual([True], list(self.get_all(shutdown)))
            finally:
                service.shutdown()
parser.add_argument("--dispatcher-side", required=False, default='compute', type=str, help=f"Where do the dispatcher run? On 'compute' side or 'training' side.", dest="dispatcher_side") parsed_args = parser.parse_args() spark = SparkSession.builder.getOrCreate() spark_context = spark.sparkContext workers = spark_context.defaultParallelism if workers % parsed_args.dispatchers: raise ValueError(f'Number of processes ({workers}) must be ' f'a multiple of number of dispatchers ({parsed_args.dispatchers}).') workers_per_dispatcher = workers // parsed_args.dispatchers key = secret.make_secret_key() compute = ComputeService(parsed_args.dispatchers, workers_per_dispatcher, key=key) compute_config = TfDataServiceConfig( dispatchers=parsed_args.dispatchers, workers_per_dispatcher=workers_per_dispatcher, dispatcher_side=parsed_args.dispatcher_side, addresses=compute.addresses(), key=key ) compute_config.write(parsed_args.configfile) def _exit_gracefully(): logging.info('Spark driver receiving SIGTERM. Exiting gracefully') spark_context.stop()
def run_elastic(
        fn,
        args=(),
        kwargs={},
        num_proc=None,
        min_num_proc=None,
        max_num_proc=None,
        start_timeout=None,
        elastic_timeout=None,
        reset_limit=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None,
        prefix_output_with_timestamp=False,
        # min_np is deprecated, use min_num_proc instead
        min_np=None,
        # max_np is deprecated, use max_num_proc instead
        max_np=None):
    """
    Runs Elastic Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        min_num_proc: Minimum number of processes running for training to continue. If number of available processes dips
                      below this threshold, then training will wait for more instances to become available.
                      Defaults to `num_proc`.
        max_num_proc: Maximum number of training processes, beyond which no additional processes will be created.
                      If not specified, defaults to `num_proc`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        elastic_timeout: Timeout for elastic initialisation after re-scaling the cluster.
                       If not set, falls back to `HOROVOD_ELASTIC_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        reset_limit: Maximum number of resets after which the job is terminated.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream.
        stderr: Horovod stderr is redirected to this stream.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.
        prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver
        min_np: Deprecated alias of `min_num_proc`; overrides it when given.
        max_np: Deprecated alias of `max_num_proc`; overrides it when given.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        ValueError: if Gloo support has not been built.
        Exception: if there is no active SparkContext.
    """
    if min_np is not None:
        min_num_proc = min_np
        # stacklevel=2 attributes the warning to the caller, not to this module
        warnings.warn('min_np is deprecated, use min_num_proc instead', DeprecationWarning, stacklevel=2)
    if max_np is not None:
        max_num_proc = max_np
        warnings.warn('max_np is deprecated, use max_num_proc instead', DeprecationWarning, stacklevel=2)

    if not gloo_built(verbose=(verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    if num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.initialExecutors
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            logging.info('Running %d processes (inferred from spark.default.parallelism)...', num_proc)
    else:
        if verbose >= 1:
            logging.info('Running %d processes...', num_proc)

    if min_num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.minExecutors
        min_num_proc = num_proc
    if max_num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.maxExecutors
        max_num_proc = num_proc

    # start Spark driver service and launch settings.num_proc Spark tasks
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(num_proc, max_num_proc,
                                               fn, args, kwargs,
                                               key, nics)

    discovery = host_discovery.SparkDriverHostDiscovery(driver)

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_elastic_settings.ElasticSettings(discovery=discovery,
                                                    min_num_proc=min_num_proc,
                                                    max_num_proc=max_num_proc,
                                                    elastic_timeout=elastic_timeout,
                                                    reset_limit=reset_limit,
                                                    num_proc=num_proc,
                                                    verbose=verbose,
                                                    key=key,
                                                    start_timeout=tmout,
                                                    nics=nics,
                                                    run_func_mode=True,
                                                    prefix_output_with_timestamp=prefix_output_with_timestamp)

    result_queue = queue.Queue(1)

    # launch settings.num_proc / settings.max_num_proc Spark tasks
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings,
                                      use_gloo=True, is_elastic=True)
    try:
        # Register task addresses of initial num_proc tasks
        _register_task_addresses(driver, settings)

        # Run the job
        gloo_run_elastic(settings, driver, env, stdout, stderr)
    except BaseException:
        # Terminate Spark job on any failure, including KeyboardInterrupt.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]
def _run_static(args):
    """Launch a static (non-elastic) Horovod job from parsed horovodrun arguments.

    Verifies ssh connectivity to all remote hosts, determines the common
    network interfaces, and launches the job via `_launch_job`. In run-func
    mode, the function is published through a temporary KVStoreServer, and
    one result per process is read back from it.

    Returns:
        The list of per-rank results in run-func mode, otherwise None.

    Raises:
        RuntimeError: if ssh to some remote host fails.
    """
    nics_set = set(args.nics.split(',')) if args.nics else None

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                                    'check connectivity between servers. You '
                                    'may need to increase the --start-timeout '
                                    'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     ssh_identity_file=args.ssh_identity_file,
                                     extra_mpi_args=args.mpi_args,
                                     tcp_flag=args.tcp_flag,
                                     binding_args=args.binding_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nics=nics_set)

    # This cache stores the results of checks performed by horovod
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        if args.ssh_identity_file:
            params += args.ssh_identity_file
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    all_host_names, _ = hosts.parse_hosts_and_slots(args.hosts)
    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        if not _check_all_hosts_ssh_successful(remote_host_names, args.ssh_port, args.ssh_identity_file,
                                               fn_cache=fn_cache):
            raise RuntimeError('could not connect to some hosts via ssh')
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    nics = driver_service.get_common_interfaces(settings, all_host_names,
                                                remote_host_names, fn_cache)

    if args.run_func:
        # get the driver IPv4 address
        driver_ip = network.get_driver_ip(nics)
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        try:
            # Publishing happens inside the try so a failed put still shuts
            # the already-started server down (previously it leaked).
            put_data_into_kvstore(driver_ip, run_func_server_port,
                                  'runfunc', 'func', args.run_func)

            command = [sys.executable, '-m', 'horovod.runner.run_task',
                       str(driver_ip), str(run_func_server_port)]
            _launch_job(args, settings, nics, command)
            # TODO: make it parallel to improve performance
            return [read_data_from_kvstore(driver_ip, run_func_server_port,
                                           'runfunc_result', str(i))
                    for i in range(args.np)]
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, settings, nics, command)
        return None
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        use_mpi=None, use_gloo=None, extra_mpi_args=None,
        env=None, stdout=None, stderr=None, verbose=1, nics=None,
        prefix_output_with_timestamp=False, executable=None):
    """
    Runs Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Request the MPI controller. Together with `use_gloo`, it is passed to
                 `is_gloo_used` and `_launch_job` to decide how the job is launched.
        use_gloo: Request the Gloo controller; see `use_mpi`.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout when used with MPI.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr when used with MPI.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.
        prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver
        executable: Optional executable to run when launching the workers. Defaults to `sys.executable`.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        Exception: if there is no active SparkContext.
        RuntimeError: if no registered Spark task has index 0.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True,
                                     prefix_output_with_timestamp=prefix_output_with_timestamp)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            logging.info('Running %d processes (inferred from spark.default.parallelism)...', num_proc)
    else:
        if settings.verbose >= 1:
            logging.info('Running %d processes...', num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, settings.num_proc,
                                               fn, args, kwargs,
                                               settings.key, settings.nics)
    gloo_is_used = is_gloo_used(use_gloo=use_gloo, use_mpi=use_mpi, use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings,
                                      use_gloo=gloo_is_used, is_elastic=False)
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes: rotate the
        # sorted host hashes once so that the host holding index 0 comes
        # first. The mapping is fetched a single time instead of once per
        # loop iteration.
        host_hash_indices = driver.task_host_hash_indices()
        host_hashes = sorted(host_hash_indices)
        first = next((i for i, host_hash in enumerate(host_hashes)
                      if 0 in host_hash_indices[host_hash]), None)
        if first is None:
            # previously this case looped forever (or hit IndexError when empty)
            raise RuntimeError('No registered Spark task has index 0.')
        host_hashes = host_hashes[first:] + host_hashes[:first]
        settings.hosts = ','.join('%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
                                  for host_hash in host_hashes)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr, executable)
    except BaseException:
        # Terminate Spark job on any failure, including KeyboardInterrupt.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]