def test_get_comm_rank(self):
    """Check ``get_comm_rank`` for worker 0 against a two-host rendezvous.

    The rendezvous server is registered with hosts 172.0.0.1 and 172.0.0.2,
    and the mocked k8s client answers every worker-service address lookup
    with ``172.0.0.1:8080``. The response must report world size 2, rank 0
    and rendezvous id 1.
    """
    rendezvous = HorovodRendezvousServer(server_host="localhost")
    self.master.rendezvous_server = rendezvous
    rendezvous.start()
    rendezvous.set_worker_hosts(["172.0.0.1", "172.0.0.2"])

    # Stub the k8s client so any worker-service address lookup resolves to
    # the first registered host (presumably how the servicer maps a
    # worker_id to a host -- confirm against MasterServicer).
    fake_k8s = Mock(
        get_worker_service_address=MagicMock(return_value="172.0.0.1:8080")
    )
    self.master.instance_manager = Mock(_k8s_client=fake_k8s)

    servicer = MasterServicer(3, evaluation_service=None, master=self.master)
    response = servicer.get_comm_rank(
        elasticdl_pb2.GetCommRankRequest(worker_id=0), None
    )

    self.assertEqual(response.world_size, 2)
    self.assertEqual(response.rank_id, 0)
    self.assertEqual(response.rendezvous_id, 1)
def test_get_comm_rank(self):
    """Check ``get_comm_rank`` for worker 0 with named worker hosts.

    The rendezvous server is given two (pod name, IP) pairs, and the mocked
    instance manager reports ``172.0.0.1`` as the pod IP for any worker. The
    response must report world size 2, rank 0 and rendezvous id 1.
    """
    rendezvous = HorovodRendezvousServer(server_host="localhost")
    self.master.rendezvous_server = rendezvous
    rendezvous.start()
    worker_hosts = [
        ("worker-0", "172.0.0.1"),
        ("worker-1", "172.0.0.2"),
    ]
    rendezvous.set_worker_hosts(worker_hosts)

    # Every pod-IP lookup resolves to the first registered host, so the
    # requesting worker is expected to land on rank 0.
    self.master.instance_manager = Mock(
        get_worker_pod_ip=MagicMock(return_value="172.0.0.1")
    )

    servicer = MasterServicer(3, evaluation_service=None, master=self.master)
    response = servicer.get_comm_rank(
        elasticdl_pb2.GetCommRankRequest(worker_id=0), None
    )

    self.assertEqual(response.world_size, 2)
    self.assertEqual(response.rank_id, 0)
    self.assertEqual(response.rendezvous_id, 1)
def test_get_comm_rank(self):
    """Check ``get_comm_rank`` for the worker at 172.0.0.1.

    Two workers (172.0.0.1 then 172.0.0.2) are added to the rendezvous
    server, and the request names 172.0.0.1 directly as its host. The
    response must report world size 2, rank 0 and rendezvous id 1.
    """
    rendezvous = HorovodRendezvousServer(server_host="localhost")
    self.master.rendezvous_server = rendezvous
    rendezvous.start()
    # Registration order is preserved; the first host added is the one
    # expected to receive rank 0 below.
    for host in ("172.0.0.1", "172.0.0.2"):
        rendezvous.add_worker(host)

    # NOTE(review): the request carries worker_host itself, so this pod-IP
    # stub may no longer be consulted -- confirm before removing it.
    self.master.instance_manager = Mock(
        get_worker_pod_ip=MagicMock(return_value="172.0.0.1")
    )

    servicer = MasterServicer(
        self.master.task_manager,
        self.master.instance_manager,
        self.master.rendezvous_server,
        None,
    )
    response = servicer.get_comm_rank(
        elasticai_api_pb2.GetCommRankRequest(worker_host="172.0.0.1"), None
    )

    self.assertEqual(response.world_size, 2)
    self.assertEqual(response.rank_id, 0)
    self.assertEqual(response.rendezvous_id, 1)