def test_blacklist_host(self): """Tests the hosts are blacklisted, resulting in changes to the available hosts.""" mock_discovery = mock.Mock() mock_discovery.find_available_hosts_and_slots.return_value = { 'a': 2, 'b': 2 } host_manager = HostManager(mock_discovery) host_manager.update_available_hosts() # Sanity check before we blacklist current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'a', 'b'} assert current_hosts.count_available_slots() == 4 # Now blacklist, our existing object should not change (immutable) host_manager.blacklist('a') assert current_hosts.available_hosts == {'a', 'b'} assert current_hosts.count_available_slots() == 4 # Check the new object, make sure we've blacklisted the host current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'b'} assert current_hosts.count_available_slots() == 2
def test_update_available_hosts(self): """Tests that the current hosts object is immutable, while fetching the latest is correctly updated.""" mock_discovery = mock.Mock() mock_discovery.find_available_hosts_and_slots.side_effect = [ {'a': 2}, {'a': 2, 'b': 2}, {'b': 2}, {'b': 1, 'c': 1}, {'b': 1, 'c': 1} ] host_manager = HostManager(mock_discovery) # Should be empty initially current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == set() assert current_hosts.count_available_slots() == 0 # From empty to {'a': 2}, it is an add update assert host_manager.update_available_hosts() == HostUpdateResult.added # First, check that nothing changed with our existing object, which is immutable assert current_hosts.available_hosts == set() assert current_hosts.count_available_slots() == 0 # Now verify that the new object has the correct sets current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'a'} assert current_hosts.count_available_slots() == 2 # Now check again # It is an increase update assert host_manager.update_available_hosts() == HostUpdateResult.added current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'a', 'b'} assert current_hosts.count_available_slots() == 4 # And again # It is a removal update assert host_manager.update_available_hosts() == HostUpdateResult.removed current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'b'} assert current_hosts.count_available_slots() == 2 # Try one more time # It is a mix update assert host_manager.update_available_hosts() == HostUpdateResult.mixed current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'b', 'c'} assert current_hosts.count_available_slots() == 2 # Finally # No change assert host_manager.update_available_hosts() == HostUpdateResult.no_update current_hosts = host_manager.current_hosts assert current_hosts.available_hosts == {'b', 'c'} assert current_hosts.count_available_slots() == 2
class ElasticDriver(object): def __init__(self, rendezvous, discovery, min_np, max_np, timeout=None, reset_limit=None, verbose=0): self._rendezvous = rendezvous self._host_manager = HostManager(discovery) self._min_np = min_np self._max_np = max_np self._verbose = verbose self._host_assignments = {} self._rank_assignments = {} self._world_size = 0 self._wait_hosts_cond = threading.Condition() self._timeout = timeout or int(os.getenv('HOROVOD_ELASTIC_TIMEOUT', ELASTIC_TIMEOUT_SECS)) self._create_worker_fn = None self._worker_clients = {} self._worker_registry = WorkerStateRegistry(self, self._host_manager, reset_limit=reset_limit) self._results = ResultsRecorder() self._shutdown = threading.Event() self._discovery_thread = threading.Thread(target=self._discover_hosts) self._discovery_thread.daemon = True self._discovery_thread.start() def start(self, np, create_worker_fn): self._create_worker_fn = create_worker_fn self._activate_workers(np) def resume(self): self._activate_workers(self._min_np) def stop(self, error_message=None): self._results.set_error_message(error_message) self._shutdown.set() self._rendezvous.stop() self._discovery_thread.join() def finished(self): return self._shutdown.is_set() def get_results(self): return self._results.get_results() def register_worker_server(self, host, slot, addresses, secret_key): self._worker_clients[(host, slot)] = WorkerNotificationClient( addresses, secret_key, self._verbose) def get_worker_client(self, slot_info): return self._worker_clients.get((slot_info.hostname, slot_info.local_rank)) def record_ready(self, host, slot): self._worker_registry.record_ready(host, slot) def world_size(self): return self._world_size def local_size(self, host): return len(self._host_assignments[host]) def get_slot_info(self, host, slot): return self._host_assignments[host][slot] if self.has_rank_assignment(host, slot) \ else hosts.INVALID_SLOT_INFO def get_coordinator_info(self): return self._rank_assignments.get(0) def has_rank_assignment(self, host, slot): if self._host_manager.is_blacklisted(host): return False return host in self._host_assignments and len(self._host_assignments[host]) > slot @property def host_assignments(self): return self._host_assignments def wait_for_available_slots(self, min_np, min_hosts=1): extra_message = ' An elastic job also requires that at least two hosts ' \ 'are available to resolve compatible network interfaces. If you know which interfaces ' \ 'are compatible in your network, set `--nic` to skip this check.' if min_hosts > 1 else '' tmout = timeout.Timeout( self._timeout, message='Timed out waiting for {{activity}}. Please check that you have ' 'enough resources to run at least {min_np} Horovod processes.{extra_message}' .format(min_np=min_np, extra_message=extra_message)) self._wait_hosts_cond.acquire() try: while True: current_hosts = self._host_manager.current_hosts if current_hosts.count_available_slots() >= min_np and len(current_hosts.available_hosts) >= min_hosts: return current_hosts if self._shutdown.is_set(): raise RuntimeError('Job has been shutdown, see above error messages for details.') self._wait_hosts_cond.wait(tmout.remaining()) tmout.check_time_out_for('minimum number of slots to become available') finally: self._wait_hosts_cond.release() def _activate_workers(self, min_np): logging.info('wait for available slots: {}'.format(min_np)) current_hosts = self.wait_for_available_slots(min_np) pending_slots = self._update_host_assignments(current_hosts) self._worker_registry.reset(self.world_size()) self._start_worker_processes(pending_slots) def _discover_hosts(self): first_update = True while not self._shutdown.is_set(): self._wait_hosts_cond.acquire() try: if self._host_manager.update_available_hosts(): self._notify_workers_host_changes(self._host_manager.current_hosts) self._wait_hosts_cond.notify_all() except RuntimeError as e: if first_update: # Misconfiguration, fail the job immediately self._shutdown.set() self._wait_hosts_cond.notify_all() raise # Transient error, retry until timeout logging.warning(str(e)) finally: self._wait_hosts_cond.release() first_update = False self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS) def _notify_workers_host_changes(self, current_hosts): next_host_assignments = {} if current_hosts.count_available_slots() >= self._min_np: # Assignments are required to be stable via contract next_host_assignments, _ = self._get_host_assignments(current_hosts) if next_host_assignments == self.host_assignments: # Skip notifying workers when host changes would not result in changes of host assignments logging.debug('no host assignment changes, skipping notifications') return coordinator_slot_info = self.get_coordinator_info() if not coordinator_slot_info: logging.debug('no coordinator info, skipping notifications') return coordinator_client = self.get_worker_client(coordinator_slot_info) if not coordinator_client: logging.debug('no coordinator client, skipping notifications') return timestamp = _epoch_time_s() try: coordinator_client.notify_hosts_updated(timestamp) except: if self._verbose >= 2: logging.exception('failed to notify {}[{}] of host updates' .format(coordinator_slot_info.hostname, coordinator_slot_info.local_rank)) def _update_host_assignments(self, current_hosts): # Determine the slots that are already filled so we do not respawn these processes active_slots = set([(host, slot_info.local_rank) for host, slots in self._host_assignments.items() for slot_info in slots]) # Adjust the host assignments to account for added / removed hosts host_assignments, host_assignments_list = self._get_host_assignments(current_hosts) if len(self._host_assignments) > 0: # Ensure that at least one previously active host is still assigned, otherwise there is no # way to sync the state to the new workers prev_hosts = self._host_assignments.keys() next_hosts = host_assignments.keys() if not prev_hosts & next_hosts: raise RuntimeError('No hosts from previous set remaining, unable to broadcast state.') self._host_assignments = host_assignments self._world_size = len(host_assignments_list) self._rendezvous.init(host_assignments_list) # Rank assignments map from world rank to slot info rank_assignments = {} for slot_info in host_assignments_list: rank_assignments[slot_info.rank] = slot_info self._rank_assignments = rank_assignments # Get the newly assigned slots that need to be started pending_slots = [slot_info for host, slots in self._host_assignments.items() for slot_info in slots if (host, slot_info.local_rank) not in active_slots] return pending_slots def _get_host_assignments(self, current_hosts): # Adjust the host assignments to account for added / removed hosts host_list = [hosts.HostInfo(host, current_hosts.get_slots(host)) for host in current_hosts.host_assignment_order] host_assignments_list = hosts.get_host_assignments(host_list, self._min_np, self._max_np) host_assignments = defaultdict(list) for slot_info in host_assignments_list: host_assignments[slot_info.hostname].append(slot_info) return host_assignments, host_assignments_list def _start_worker_processes(self, pending_slots): for slot_info in pending_slots: logging.info('start worker process: {}[{}]'.format(slot_info.hostname, slot_info.local_rank)) self._start_worker_process(slot_info) def _start_worker_process(self, slot_info): create_worker_fn = self._create_worker_fn shutdown_event = self._shutdown host_event = self._host_manager.get_host_event(slot_info.hostname) def run_worker(): res = create_worker_fn(slot_info, [shutdown_event, host_event]) exit_code, timestamp = res self._handle_worker_exit(slot_info, exit_code, timestamp) thread = threading.Thread(target=run_worker) thread.daemon = True thread.start() self._results.expect(thread) def _handle_worker_exit(self, slot_info, exit_code, timestamp): if not self.has_rank_assignment(slot_info.hostname, slot_info.local_rank): # Ignore hosts that are not assigned a rank logging.debug('host {} has been blacklisted, ignoring exit from local_rank={}' .format(slot_info.hostname, slot_info.local_rank)) return if exit_code == 0: rendezvous_id = self._worker_registry.record_success(slot_info.hostname, slot_info.local_rank) else: rendezvous_id = self._worker_registry.record_failure(slot_info.hostname, slot_info.local_rank) if self.finished() and self._worker_registry.last_rendezvous() == rendezvous_id: logging.debug('adding results for {}[{}]: ({}, {})' .format(slot_info.hostname, slot_info.local_rank, exit_code, timestamp)) name = '{}[{}]'.format(slot_info.hostname, slot_info.local_rank) self._results.add_result(name, (exit_code, timestamp))