Example #1
0
    def test_rank_and_size(self):
        """Verify rank/size bookkeeping on two hosts with two slots each.

        Every worker marks itself ready and then re-reads its slot info from
        the driver; the re-read info must serialize identically to the slot
        info the worker was launched with.
        """
        discovery = FixedHosts({'host-1': 2, 'host-2': 2})

        driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
        driver.wait_for_available_slots(min_np=2)

        # rank -> (slot info at launch, slot info read back after record_ready)
        observed = {}

        def exec_command(slot_info, events):
            host = slot_info.hostname
            local_rank = slot_info.local_rank
            driver.record_ready(host, local_rank)
            refreshed = driver.get_slot_info(host, local_rank)
            observed[slot_info.rank] = (slot_info, refreshed)
            return 0, time.time()

        driver.start(np=2, create_worker_fn=exec_command)
        worker_results = driver.get_results().worker_results
        driver.stop()

        assert len(worker_results) == 4
        for worker_name, (code, _stamp) in worker_results.items():
            assert code == 0, worker_name

        assert len(observed) == 4
        for rank, (launched, refreshed) in observed.items():
            launched_str = launched.to_response_string()
            refreshed_str = refreshed.to_response_string()
            assert launched_str == refreshed_str, rank
Example #2
0
    def test_shutdown_on_success(self):
        """Verify that one successful worker ends the run for the others.

        Rank 0 returns exit code 0 immediately. Every other worker marks itself
        ready, blocks in ``wait_for_one(events)`` (presumably until the
        driver's shutdown event fires — confirm), and then returns exit code 1.
        With four workers, the exit codes are expected to sum to 3.
        """
        discovery = FixedHosts({'host-1': 2, 'host-2': 2})

        driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
        driver.wait_for_available_slots(min_np=2)

        def exec_command(slot_info, events):
            if slot_info.rank != 0:
                driver.record_ready(slot_info.hostname, slot_info.local_rank)
                wait_for_one(events)
                return 1, time.time()
            return 0, time.time()

        driver.start(np=2, create_worker_fn=exec_command)
        worker_results = driver.get_results().worker_results
        driver.stop()

        assert len(worker_results) == 4
        assert sum(code for code, _stamp in worker_results.values()) == 3
Example #3
0
    def test_rank_and_size_with_host_failure(self):
        """Verify rank reassignment when host-2 fails before rendezvous completes.

        Workers on host-2 exit at once with code 1. The two surviving host-1
        workers must see a world of size 2 on a single host (cross_size 1,
        cross_rank 0). Their rank is expected to be the launch rank modulo 2,
        and their local_size/local_rank must be unchanged.
        """
        discovery = FixedHosts({'host-1': 2, 'host-2': 2})

        # NOTE(review): this test uses the ``*_num_proc`` keyword names, while
        # the sibling tests use ``*_np``; keep whichever matches the
        # ElasticDriver version actually under test.
        driver = ElasticDriver(mock.Mock(), discovery, min_num_proc=2, max_num_proc=4)
        driver.wait_for_available_slots(min_num_proc=2)

        # rank -> (slot info at launch, slot info read back after record_ready)
        observed = {}

        def exec_command(slot_info, events):
            host = slot_info.hostname
            if host == 'host-2':
                return 1, time.time()

            driver.record_ready(host, slot_info.local_rank)
            refreshed = driver.get_slot_info(host, slot_info.local_rank)
            observed[slot_info.rank] = (slot_info, refreshed)
            return 0, time.time()

        driver.start(num_proc=2, create_worker_fn=exec_command)
        worker_results = driver.get_results().worker_results
        driver.stop()

        assert len(worker_results) == 2
        for worker_name, (code, _stamp) in worker_results.items():
            assert code == 0, worker_name

        assert len(observed) == 2
        for rank, (launched, refreshed) in observed.items():
            assert refreshed.size == 2, rank
            assert refreshed.rank == launched.rank % 2, rank
            assert refreshed.local_size == launched.local_size, rank
            assert refreshed.local_rank == launched.local_rank, rank
            assert refreshed.cross_size == 1, rank
            assert refreshed.cross_rank == 0, rank
Example #4
0
    def test_rank_and_size_with_host_added(self):
        """Tests training starts with one host two slots, then a second host is added.

        Discovery initially reports only host-1. The host-1 worker with rank 0
        adds host-2 to discovery. All host-1 workers then wait until 4 slots are
        available and record readiness again. All 4 workers must end up seeing
        size 4 and cross_size 2, with rank, local_size, local_rank and
        cross_rank equal to the values they were launched with.
        """
        slots = {'host-1': 2}
        discovery = FixedHosts(slots)

        def add_host():
            # Make discovery report a second host with two more slots.
            slots = {'host-1': 2, 'host-2': 2}
            discovery.set(slots)

        driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
        driver.wait_for_available_slots(min_np=2)

        # rank -> (slot info at launch, slot info read back after readiness)
        rank_results = {}

        def exec_command(slot_info, events):
            driver.record_ready(slot_info.hostname, slot_info.local_rank)

            if slot_info.hostname == 'host-1':
                # Only one worker triggers the host addition; every host-1
                # worker then blocks until the expanded slot count is visible.
                if slot_info.rank == 0:
                    add_host()
                driver.wait_for_available_slots(4)
                # NOTE(review): for host-1 workers this call is immediately
                # followed by a second record_ready below. Whether the double
                # call is needed by the driver's rendezvous logic cannot be
                # determined from here; confirm before removing either call.
                driver.record_ready(slot_info.hostname, slot_info.local_rank)

            driver.record_ready(slot_info.hostname, slot_info.local_rank)
            updated_slot_info = driver.get_slot_info(slot_info.hostname,
                                                     slot_info.local_rank)
            rank_results[slot_info.rank] = (slot_info, updated_slot_info)
            return 0, time.time()

        driver.start(np=2, create_worker_fn=exec_command)
        res = driver.get_results().worker_results
        driver.stop()

        assert len(res) == 4
        for name, (exit_code, timestamp) in res.items():
            assert exit_code == 0, name

        assert len(rank_results) == 4
        for rank, (slot_info, updated_slot_info) in rank_results.items():
            assert updated_slot_info.size == 4, rank
            assert updated_slot_info.rank == slot_info.rank, rank
            assert updated_slot_info.local_size == slot_info.local_size, rank
            assert updated_slot_info.local_rank == slot_info.local_rank, rank
            assert updated_slot_info.cross_size == 2, rank
            assert updated_slot_info.cross_rank == slot_info.cross_rank, rank
Example #5
0
    def test_all_workers_fail(self):
        """Verify the run reports failure when every worker exits non-zero.

        Each of the four workers marks itself ready and then returns exit
        code 1; every recorded worker result must carry that code.
        """
        discovery = FixedHosts({'host-1': 2, 'host-2': 2})

        driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
        driver.wait_for_available_slots(min_np=2)

        def exec_command(slot_info, events):
            driver.record_ready(slot_info.hostname, slot_info.local_rank)
            return 1, time.time()

        driver.start(np=2, create_worker_fn=exec_command)
        worker_results = driver.get_results().worker_results
        driver.stop()

        assert len(worker_results) == 4
        for worker_name, outcome in worker_results.items():
            code, _stamp = outcome
            assert code == 1, worker_name
Example #6
0
    def test_worker_notification_manager(self):
        """Tests that host add/remove events reach worker notification listeners.

        Flow:
        - Discovery starts with host-1 only.
        - Rank 0 adds host-2, and every worker waits for 4 available slots.
        - Rank 0 then switches discovery to host-2 only.
        - Each worker busy-waits until the driver reports at most 2 available
          slots.

        Rank 0's listener must have received exactly two updates, an add
        followed by a removal. The other worker's listener must have received
        none.

        The driver and the rendezvous server are always stopped, even when the
        run or an assertion fails, so a failing test does not leave the server
        running.
        """
        slots = {'host-1': 2}
        discovery = FixedHosts(slots)

        rendezvous = RendezvousServer()
        driver = ElasticDriver(rendezvous, discovery, min_np=2, max_np=4)
        driver.wait_for_available_slots(min_np=2)
        handler = create_rendezvous_handler(driver)

        common_intfs = network.get_local_intfs()
        addr = network.get_driver_ip(common_intfs)
        # Resolve the NIC before starting the server so an empty interface
        # list cannot leave a started server behind.
        nic = list(common_intfs)[0]

        # rank -> list of (timestamp, update result) received by that worker
        rank_results = {}

        class NotificationReceiver:
            """Listener that records every host-update notification it gets."""

            def __init__(self):
                self.events = []

            def on_hosts_updated(self, timestamp, res):
                self.events.append((timestamp, res))

        def add_host():
            slots = {'host-1': 2, 'host-2': 2}
            discovery.set(slots)

        def remove_host():
            slots = {'host-2': 2}
            discovery.set(slots)

        def exec_command(slot_info, events):
            # `port` is bound by the time workers run: driver.start is only
            # called after rendezvous.start below.
            manager = WorkerNotificationManager()
            manager.init(rendezvous_addr=addr,
                         rendezvous_port=port,
                         nic=nic,
                         hostname=slot_info.hostname,
                         local_rank=slot_info.local_rank)

            notification_receiver = NotificationReceiver()
            manager.register_listener(notification_receiver)

            driver.record_ready(slot_info.hostname, slot_info.local_rank)

            if slot_info.rank == 0:
                add_host()
            driver.wait_for_available_slots(4)

            if slot_info.rank == 0:
                remove_host()

            # Busy wait for the number of available slots to decrease
            while driver._host_manager.current_hosts.count_available_slots(
            ) > 2:
                time.sleep(0.01)

            rank_results[slot_info.rank] = notification_receiver.events
            return 0, time.time()

        port = rendezvous.start(handler)
        try:
            driver.start(np=2, create_worker_fn=exec_command)
            try:
                res = driver.get_results().worker_results
            finally:
                driver.stop()

            assert len(res) == 2
            for name, (exit_code, _timestamp) in res.items():
                assert exit_code == 0, name

            assert len(rank_results) == 2
            for rank, events in rank_results.items():
                expected = 2 if rank == 0 else 0
                assert len(events) == expected, rank
                if rank == 0:
                    # First update is an add
                    assert events[0][1] == HostUpdateResult.added, rank
                    # Second update is a removal
                    assert events[1][1] == HostUpdateResult.removed, rank
        finally:
            rendezvous.stop()