def test_rank_and_size(self):
    """Two hosts with two slots each: slot info is stable across record_ready.

    Every worker records itself as ready, then asks the driver for its slot
    info again. The re-fetched info must serialize identically to the info
    the worker was launched with, and all four workers must exit with 0.
    """
    discovery = FixedHosts({'host-1': 2, 'host-2': 2})
    driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
    driver.wait_for_available_slots(min_np=2)

    # rank -> (slot info at launch, slot info fetched after record_ready)
    observed = {}

    def run_worker(slot_info, events):
        host, local_rank = slot_info.hostname, slot_info.local_rank
        driver.record_ready(host, local_rank)
        observed[slot_info.rank] = (slot_info, driver.get_slot_info(host, local_rank))
        return 0, time.time()

    driver.start(np=2, create_worker_fn=run_worker)
    worker_results = driver.get_results().worker_results
    driver.stop()

    assert len(worker_results) == 4
    for worker_name, (exit_code, _) in worker_results.items():
        assert exit_code == 0, worker_name

    assert len(observed) == 4
    for rank, (launched, refreshed) in observed.items():
        assert launched.to_response_string() == refreshed.to_response_string(), rank
def test_shutdown_on_success(self):
    """A single successful worker should end the job for everyone else.

    Rank 0 returns exit code 0 immediately. Every other worker records
    itself ready, waits on its ``events`` via ``wait_for_one``, and then
    returns exit code 1. With four workers, the exit codes must sum to 3.
    """
    discovery = FixedHosts({'host-1': 2, 'host-2': 2})
    driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
    driver.wait_for_available_slots(min_np=2)

    def run_worker(slot_info, events):
        if slot_info.rank != 0:
            driver.record_ready(slot_info.hostname, slot_info.local_rank)
            # Blocks until the driver signals this worker to stop.
            wait_for_one(events)
            return 1, time.time()
        return 0, time.time()

    driver.start(np=2, create_worker_fn=run_worker)
    worker_results = driver.get_results().worker_results
    driver.stop()

    assert len(worker_results) == 4
    assert sum(code for code, _ in worker_results.values()) == 3
def test_rank_and_size_with_host_failure(self):
    """Two hosts, two slots each, with the second host failing before rendezvous completes.

    Every worker on host-2 exits with code 1 before calling ``record_ready``.
    Only the two host-1 workers therefore remain. Both must exit with 0, and
    the slot info the driver returns to them after ``record_ready`` must
    describe a single-host job of size 2.
    """
    slots = {'host-1': 2, 'host-2': 2}
    discovery = FixedHosts(slots)
    # Use the same keyword names (min_np / max_np / np) as every other
    # ElasticDriver call in this file.
    driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
    driver.wait_for_available_slots(min_np=2)

    # rank -> (slot info at launch, slot info fetched after record_ready)
    rank_results = {}

    def exec_command(slot_info, events):
        # Simulate host-2 failing before it ever reports ready.
        if slot_info.hostname == 'host-2':
            return 1, time.time()

        driver.record_ready(slot_info.hostname, slot_info.local_rank)
        updated_slot_info = driver.get_slot_info(slot_info.hostname, slot_info.local_rank)
        rank_results[slot_info.rank] = (slot_info, updated_slot_info)
        return 0, time.time()

    driver.start(np=2, create_worker_fn=exec_command)
    res = driver.get_results().worker_results
    driver.stop()

    # Only the two surviving (host-1) workers appear in the final results.
    assert len(res) == 2
    for name, (exit_code, _) in res.items():
        assert exit_code == 0, name

    assert len(rank_results) == 2
    for rank, (slot_info, updated_slot_info) in rank_results.items():
        assert updated_slot_info.size == 2, rank
        assert updated_slot_info.rank == slot_info.rank % 2, rank
        assert updated_slot_info.local_size == slot_info.local_size, rank
        assert updated_slot_info.local_rank == slot_info.local_rank, rank
        # A single host remains, so there is exactly one cross-host group.
        assert updated_slot_info.cross_size == 1, rank
        assert updated_slot_info.cross_rank == 0, rank
def test_rank_and_size_with_host_added(self):
    """Start on one host with two slots, then add a second host mid-run.

    Rank 0 publishes a second two-slot host through discovery. The host-1
    workers wait until four slots are available and record ready again.
    Every worker then records ready and fetches its updated slot info, which
    must describe a four-worker, two-host job.
    """
    discovery = FixedHosts({'host-1': 2})

    def add_host():
        discovery.set({'host-1': 2, 'host-2': 2})

    driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
    driver.wait_for_available_slots(min_np=2)

    # rank -> (slot info at launch, slot info fetched after the final record_ready)
    observed = {}

    def run_worker(slot_info, events):
        host, local_rank = slot_info.hostname, slot_info.local_rank
        driver.record_ready(host, local_rank)

        if host == 'host-1':
            # Rank 0 triggers the scale-up; both host-1 workers then wait for
            # the new slots and record ready again.
            if slot_info.rank == 0:
                add_host()
            driver.wait_for_available_slots(4)
            driver.record_ready(host, local_rank)

        driver.record_ready(host, local_rank)
        observed[slot_info.rank] = (slot_info, driver.get_slot_info(host, local_rank))
        return 0, time.time()

    driver.start(np=2, create_worker_fn=run_worker)
    worker_results = driver.get_results().worker_results
    driver.stop()

    assert len(worker_results) == 4
    for worker_name, (exit_code, _) in worker_results.items():
        assert exit_code == 0, worker_name

    assert len(observed) == 4
    for rank, (launched, refreshed) in observed.items():
        assert refreshed.size == 4, rank
        assert refreshed.rank == launched.rank, rank
        assert refreshed.local_size == launched.local_size, rank
        assert refreshed.local_rank == launched.local_rank, rank
        assert refreshed.cross_size == 2, rank
        assert refreshed.cross_rank == launched.cross_rank, rank
def test_all_workers_fail(self):
    """If every worker exits with code 1, every reported result is a failure."""
    discovery = FixedHosts({'host-1': 2, 'host-2': 2})
    driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
    driver.wait_for_available_slots(min_np=2)

    def run_worker(slot_info, events):
        # Report ready so the worker joins the rendezvous, then fail.
        driver.record_ready(slot_info.hostname, slot_info.local_rank)
        return 1, time.time()

    driver.start(np=2, create_worker_fn=run_worker)
    worker_results = driver.get_results().worker_results
    driver.stop()

    assert len(worker_results) == 4
    for worker_name, (exit_code, _) in worker_results.items():
        assert exit_code == 1, worker_name
def test_worker_notification_manager(self):
    """Tests that host add events are sent to the worker notification service and consumed.

    Each worker registers a ``NotificationReceiver`` with a
    ``WorkerNotificationManager`` pointed at a real ``RendezvousServer``.
    Rank 0 first adds host-2 through discovery. All workers wait for four
    available slots. Rank 0 then removes host-1, and every worker busy-waits
    until at most two slots are available. Rank 0 must have received exactly
    two notifications (an add, then a removal); rank 1 must have received
    none.

    The rendezvous server is stopped in a ``finally`` block, so it is shut
    down even when an assertion or a driver call fails.
    """
    slots = {'host-1': 2}
    discovery = FixedHosts(slots)

    rendezvous = RendezvousServer()
    driver = ElasticDriver(rendezvous, discovery, min_np=2, max_np=4)
    driver.wait_for_available_slots(min_np=2)
    handler = create_rendezvous_handler(driver)

    common_intfs = network.get_local_intfs()
    addr = network.get_driver_ip(common_intfs)
    port = rendezvous.start(handler)
    try:
        nic = list(common_intfs)[0]

        # rank -> list of (timestamp, HostUpdateResult) notifications received
        rank_results = {}

        class NotificationReceiver:
            """Listener that records every host-update notification it gets."""

            def __init__(self):
                self.events = []

            def on_hosts_updated(self, timestamp, res):
                self.events.append((timestamp, res))

        def add_host():
            slots = {'host-1': 2, 'host-2': 2}
            discovery.set(slots)

        def remove_host():
            slots = {'host-2': 2}
            discovery.set(slots)

        def exec_command(slot_info, events):
            manager = WorkerNotificationManager()
            manager.init(rendezvous_addr=addr,
                         rendezvous_port=port,
                         nic=nic,
                         hostname=slot_info.hostname,
                         local_rank=slot_info.local_rank)

            notification_receiver = NotificationReceiver()
            manager.register_listener(notification_receiver)

            driver.record_ready(slot_info.hostname, slot_info.local_rank)

            if slot_info.rank == 0:
                add_host()
            driver.wait_for_available_slots(4)

            if slot_info.rank == 0:
                remove_host()

            # Busy wait for the number of available slots to decrease
            while driver._host_manager.current_hosts.count_available_slots() > 2:
                time.sleep(0.01)

            rank_results[slot_info.rank] = notification_receiver.events
            return 0, time.time()

        driver.start(np=2, create_worker_fn=exec_command)
        res = driver.get_results().worker_results
        driver.stop()

        assert len(res) == 2
        for name, (exit_code, _) in res.items():
            assert exit_code == 0, name

        assert len(rank_results) == 2
        for rank, events in rank_results.items():
            expected = 2 if rank == 0 else 0
            assert len(events) == expected, rank
            if rank == 0:
                # First update is an add
                assert events[0][1] == HostUpdateResult.added
                # Second update is a removal
                assert events[1][1] == HostUpdateResult.removed
    finally:
        # Stop the server started above on every path, not only on success.
        rendezvous.stop()