class HorovodRendezvousServerTest(unittest.TestCase):
    """Tests for ``HorovodRendezvousServer`` fed (worker name, host IP) pairs."""

    # Address the server under test is constructed with; get_rendezvous_host()
    # is expected to report it back unchanged.
    _SERVER_HOST = "127.0.0.1"

    @staticmethod
    def _initial_worker_name_hosts():
        """Return a fresh list of (worker name, host IP) pairs for two workers."""
        return [
            ("worker-0", "127.0.0.2"),
            ("worker-1", "127.0.0.3"),
        ]

    def setUp(self):
        # NOTE(review): a server is started for every test, but this class
        # never stops it (there is no tearDown). Consider adding cleanup if
        # HorovodRendezvousServer exposes a shutdown method.
        server = self.rendezvous_server = HorovodRendezvousServer(
            server_host=self._SERVER_HOST
        )
        server.start()

    def test_get_host_plan(self):
        server = self.rendezvous_server
        server._worker_name_hosts = self._initial_worker_name_hosts()
        plan = server._get_host_plan()

        # Each plan slot should carry its host, its position as rank, and the
        # total number of workers as size.
        expected_hosts = ["127.0.0.2", "127.0.0.3"]
        for expected_rank, expected_host in enumerate(expected_hosts):
            slot = plan[expected_rank]
            self.assertEqual(slot.hostname, expected_host)
            self.assertEqual(slot.rank, expected_rank)
            self.assertEqual(slot.size, len(expected_hosts))

    def test_set_worker_hosts(self):
        server = self.rendezvous_server
        server.set_worker_hosts(self._initial_worker_name_hosts())

        observed_ranks = [
            server.get_worker_host_rank(host)
            for host in ("127.0.0.2", "127.0.0.3")
        ]
        for expected_rank, observed_rank in enumerate(observed_ranks):
            self.assertEqual(observed_rank, expected_rank)
        self.assertEqual(server._rendezvous_completed, True)
        self.assertEqual(server._rendezvous_id, 1)

        # worker-0 is replaced by worker-2 on a different host. After an
        # explicit re-initialization the rendezvous id is expected to be 2.
        server.set_worker_hosts(
            [
                ("worker-2", "127.0.0.1"),
                ("worker-1", "127.0.0.3"),
            ]
        )
        server._init_rendezvous_server()
        self.assertEqual(server._rendezvous_id, 2)

    def test_get_attr(self):
        server = self.rendezvous_server
        server.set_worker_hosts(self._initial_worker_name_hosts())
        self.assertEqual(server.get_rendezvous_host(), self._SERVER_HOST)
        self.assertEqual(server.get_worker_host_rank("127.0.0.2"), 0)
        self.assertEqual(server.get_size(), 2)
        self.assertEqual(server.get_rendezvous_id(), 1)
# Example No. 2
# 0
class HorovodRendezvousServerTest(unittest.TestCase):
    """Tests for ``HorovodRendezvousServer`` fed a plain list of host IPs."""

    # Address the server under test is constructed with; get_rendezvous_host()
    # is expected to report it back unchanged.
    _SERVER_HOST = "127.0.0.1"

    @staticmethod
    def _initial_worker_hosts():
        """Return a fresh list of host IPs for two workers."""
        return ["127.0.0.2", "127.0.0.3"]

    def setUp(self):
        # NOTE(review): a server is started for every test, but this class
        # never stops it (there is no tearDown). Consider adding cleanup if
        # HorovodRendezvousServer exposes a shutdown method.
        server = self.rendezvous_server = HorovodRendezvousServer(
            server_host=self._SERVER_HOST
        )
        server.start()

    def test_get_host_plan(self):
        server = self.rendezvous_server
        server._worker_hosts = self._initial_worker_hosts()
        plan = server._get_host_plan()

        # Each plan slot should carry its host, its position as rank, and the
        # total number of workers as size.
        expected_hosts = ["127.0.0.2", "127.0.0.3"]
        for expected_rank, expected_host in enumerate(expected_hosts):
            slot = plan[expected_rank]
            self.assertEqual(slot.hostname, expected_host)
            self.assertEqual(slot.rank, expected_rank)
            self.assertEqual(slot.size, len(expected_hosts))

    def test_set_worker_hosts(self):
        server = self.rendezvous_server
        server.set_worker_hosts(self._initial_worker_hosts())
        self.assertEqual(server._rendezvous_id, 1)

        # Swapping 127.0.0.2 for 127.0.0.1 is expected to advance the id to 2.
        server.set_worker_hosts(["127.0.0.1", "127.0.0.3"])
        self.assertEqual(server._rendezvous_id, 2)

    def test_get_attr(self):
        server = self.rendezvous_server
        server.set_worker_hosts(self._initial_worker_hosts())
        self.assertEqual(server.get_rendezvous_host(), self._SERVER_HOST)
        self.assertEqual(server.get_worker_host_rank("127.0.0.2"), 0)
        self.assertEqual(server.get_size(), 2)
        self.assertEqual(server.get_rendezvous_id(), 1)