Пример #1
0
    def test_get_comm_rank(self):
        """Worker 0 should receive rank 0 in a two-host rendezvous.

        Starts a localhost rendezvous server seeded with two worker hosts,
        then mocks the instance manager's ``_k8s_client`` so that
        ``get_worker_service_address`` returns ``"172.0.0.1:8080"``.
        The servicer's ``get_comm_rank`` response for worker 0 must report
        world size 2, rank 0 and rendezvous id 1.
        """
        self.master.rendezvous_server = HorovodRendezvousServer(
            server_host="localhost"
        )
        rendezvous = self.master.rendezvous_server
        # NOTE(review): the server started here is never stopped by this
        # test; if HorovodRendezvousServer exposes a shutdown hook, consider
        # registering it with self.addCleanup.
        rendezvous.start()
        rendezvous.set_worker_hosts(["172.0.0.1", "172.0.0.2"])

        # The mocked service address carries a port, while the registered
        # hosts do not. Presumably the servicer maps "host:port" back to the
        # bare host; TODO confirm.
        self.master.instance_manager = Mock(
            _k8s_client=Mock(
                get_worker_service_address=MagicMock(
                    return_value="172.0.0.1:8080"
                )
            )
        )
        # The leading positional 3 is presumably the minibatch size; confirm.
        servicer = MasterServicer(
            3, evaluation_service=None, master=self.master
        )
        request = elasticdl_pb2.GetCommRankRequest(worker_id=0)
        response = servicer.get_comm_rank(request, None)

        self.assertEqual(response.world_size, 2)
        self.assertEqual(response.rank_id, 0)
        self.assertEqual(response.rendezvous_id, 1)
Пример #2
0
    def test_get_comm_rank(self):
        """Worker 0 should receive rank 0 when hosts are (name, ip) pairs.

        Starts a localhost rendezvous server whose worker hosts are given as
        ``(pod_name, ip)`` tuples. It then replaces the instance manager
        with a mock whose ``get_worker_pod_ip`` always returns
        ``"172.0.0.1"``. The servicer's ``get_comm_rank`` response for
        worker 0 must report world size 2, rank 0 and rendezvous id 1.
        """
        self.master.rendezvous_server = HorovodRendezvousServer(
            server_host="localhost"
        )
        rendezvous = self.master.rendezvous_server
        # NOTE(review): the server started here is never stopped by this
        # test; if HorovodRendezvousServer exposes a shutdown hook, consider
        # registering it with self.addCleanup.
        rendezvous.start()
        worker_hosts = [
            ("worker-0", "172.0.0.1"),
            ("worker-1", "172.0.0.2"),
        ]
        rendezvous.set_worker_hosts(worker_hosts)

        self.master.instance_manager = Mock(
            get_worker_pod_ip=MagicMock(return_value="172.0.0.1")
        )
        # The leading positional 3 is presumably the minibatch size; confirm.
        servicer = MasterServicer(
            3, evaluation_service=None, master=self.master
        )
        request = elasticdl_pb2.GetCommRankRequest(worker_id=0)
        response = servicer.get_comm_rank(request, None)

        self.assertEqual(response.world_size, 2)
        self.assertEqual(response.rank_id, 0)
        self.assertEqual(response.rendezvous_id, 1)
Пример #3
0
    def test_get_comm_rank(self):
        """The first of two registered hosts should receive rank 0.

        Starts a localhost rendezvous server and registers two workers one
        at a time via ``add_worker``. The servicer is built from the master's
        task manager, instance manager and rendezvous server. A request
        identifying the caller by ``worker_host="172.0.0.1"`` must yield
        world size 2, rank 0 and rendezvous id 1.
        """
        self.master.rendezvous_server = HorovodRendezvousServer(
            server_host="localhost"
        )
        rendezvous = self.master.rendezvous_server
        # NOTE(review): the server started here is never stopped by this
        # test; if HorovodRendezvousServer exposes a shutdown hook, consider
        # registering it with self.addCleanup.
        rendezvous.start()
        for host in ("172.0.0.1", "172.0.0.2"):
            rendezvous.add_worker(host)

        # NOTE(review): the request below already carries the worker host,
        # so this pod-IP mock may not be consulted by get_comm_rank; confirm.
        self.master.instance_manager = Mock(
            get_worker_pod_ip=MagicMock(return_value="172.0.0.1")
        )
        servicer = MasterServicer(
            self.master.task_manager,
            self.master.instance_manager,
            self.master.rendezvous_server,
            None,
        )
        request = elasticai_api_pb2.GetCommRankRequest(
            worker_host="172.0.0.1"
        )
        response = servicer.get_comm_rank(request, None)

        self.assertEqual(response.world_size, 2)
        self.assertEqual(response.rank_id, 0)
        self.assertEqual(response.rendezvous_id, 1)