class HorovodRendezvousServerTest(unittest.TestCase):
    """Unit tests for ``HorovodRendezvousServer`` fed with named worker hosts.

    In this variant the server receives ``(worker_name, worker_host)`` pairs.
    The tests check:

    * the host allocation plan;
    * per-host ranks;
    * the reported size and rendezvous host;
    * how the rendezvous id advances.

    NOTE(review): another class named ``HorovodRendezvousServerTest`` appears
    later in this file chunk. If both live in the same module, the later
    definition rebinds the name and these tests are never collected. Confirm
    which version is intended.
    """

    def setUp(self):
        # Each test gets a fresh server bound to the loopback address.
        # test_get_attr expects get_rendezvous_host() to echo this address.
        server = self.rendezvous_server = HorovodRendezvousServer(
            server_host="127.0.0.1"
        )
        server.start()
        # NOTE(review): nothing stops the server after each test. If the
        # class exposes a stop/shutdown method, consider calling it from a
        # tearDown (or via addCleanup).

    def test_get_host_plan(self):
        # The test expects the plan to follow the order of _worker_name_hosts.
        # Entry i carries the i-th host and rank i, and every entry reports
        # the total number of hosts as its size.
        server = self.rendezvous_server
        server._worker_name_hosts = [
            ("worker-0", "127.0.0.2"),
            ("worker-1", "127.0.0.3"),
        ]
        plan = server._get_host_plan()

        expected_hosts = ("127.0.0.2", "127.0.0.3")
        for expected_rank, expected_host in enumerate(expected_hosts):
            slot = plan[expected_rank]
            self.assertEqual(slot.hostname, expected_host)
            self.assertEqual(slot.rank, expected_rank)
            self.assertEqual(slot.size, len(expected_hosts))

    def test_set_worker_hosts(self):
        server = self.rendezvous_server
        server.set_worker_hosts(
            [("worker-0", "127.0.0.2"), ("worker-1", "127.0.0.3")]
        )

        # Look up both ranks before asserting either one.
        ranks = [
            server.get_worker_host_rank(host)
            for host in ("127.0.0.2", "127.0.0.3")
        ]
        for actual_rank, expected_rank in zip(ranks, (0, 1)):
            self.assertEqual(actual_rank, expected_rank)

        self.assertEqual(server._rendezvous_completed, True)
        self.assertEqual(server._rendezvous_id, 1)

        # Replace one worker. The test explicitly re-initialises the
        # rendezvous server before expecting the id to have advanced to 2.
        server.set_worker_hosts(
            [("worker-2", "127.0.0.1"), ("worker-1", "127.0.0.3")]
        )
        server._init_rendezvous_server()
        self.assertEqual(server._rendezvous_id, 2)

    def test_get_attr(self):
        server = self.rendezvous_server
        server.set_worker_hosts(
            [("worker-0", "127.0.0.2"), ("worker-1", "127.0.0.3")]
        )

        # Each entry is (accessor, positional args, expected value).
        expectations = (
            (server.get_rendezvous_host, (), "127.0.0.1"),
            (server.get_worker_host_rank, ("127.0.0.2",), 0),
            (server.get_size, (), 2),
            (server.get_rendezvous_id, (), 1),
        )
        for accessor, args, expected in expectations:
            self.assertEqual(accessor(*args), expected)
class HorovodRendezvousServerTest(unittest.TestCase):
    """Unit tests for ``HorovodRendezvousServer`` fed with plain host lists.

    In this variant the server receives a list of host address strings. The
    tests check:

    * the host allocation plan;
    * how the rendezvous id advances on each host update;
    * the public accessors.

    NOTE(review): this class name is also defined earlier in this file chunk.
    If both live in the same module, this definition rebinds the name and the
    earlier tests are discarded. The two variants also pass different argument
    shapes to set_worker_hosts. Confirm which one matches the current server.
    """

    def setUp(self):
        # Each test gets a fresh server bound to the loopback address.
        # test_get_attr expects get_rendezvous_host() to echo this address.
        server = self.rendezvous_server = HorovodRendezvousServer(
            server_host="127.0.0.1"
        )
        server.start()
        # NOTE(review): nothing stops the server after each test. If the
        # class exposes a stop/shutdown method, consider calling it from a
        # tearDown (or via addCleanup).

    def test_get_host_plan(self):
        # The test expects the plan to follow the order of _worker_hosts.
        # Entry i carries the i-th host and rank i, and every entry reports
        # the total number of hosts as its size.
        server = self.rendezvous_server
        server._worker_hosts = ["127.0.0.2", "127.0.0.3"]
        plan = server._get_host_plan()

        expected_hosts = ("127.0.0.2", "127.0.0.3")
        for expected_rank, expected_host in enumerate(expected_hosts):
            slot = plan[expected_rank]
            self.assertEqual(slot.hostname, expected_host)
            self.assertEqual(slot.rank, expected_rank)
            self.assertEqual(slot.size, len(expected_hosts))

    def test_set_worker_hosts(self):
        # Each successive call to set_worker_hosts is expected to advance
        # the rendezvous id by one, starting from 1.
        server = self.rendezvous_server
        host_updates = (
            ["127.0.0.2", "127.0.0.3"],
            ["127.0.0.1", "127.0.0.3"],
        )
        for expected_id, hosts in enumerate(host_updates, start=1):
            server.set_worker_hosts(hosts)
            self.assertEqual(server._rendezvous_id, expected_id)

    def test_get_attr(self):
        server = self.rendezvous_server
        server.set_worker_hosts(["127.0.0.2", "127.0.0.3"])

        # Each entry is (accessor, positional args, expected value).
        expectations = (
            (server.get_rendezvous_host, (), "127.0.0.1"),
            (server.get_worker_host_rank, ("127.0.0.2",), 0),
            (server.get_size, (), 2),
            (server.get_rendezvous_id, (), 1),
        )
        for accessor, args, expected in expectations:
            self.assertEqual(accessor(*args), expected)