def test_mxnet_cluster_config(self):
    """Check that an MXNet cluster spec survives a dict round trip.

    Builds a spec with three workers and two servers, parses it with
    ``MXNetClusterConfig.from_dict`` and asserts that ``to_dict`` gives back
    a dict equal to the input.
    """
    worker_hosts = [
        "worker0.example.com:2222",
        "worker1.example.com:2222",
        "worker2.example.com:2222",
    ]
    server_hosts = ["server0.example.com:2222", "server1.example.com:2222"]
    expected = {"worker": worker_hosts, "server": server_hosts}

    round_tripped = MXNetClusterConfig.from_dict(expected).to_dict()
    assert_equal_dict(expected, round_tripped)
Example #2
0
    def get_cluster(self):
        """Return the MXNet cluster layout for this job as a plain dict.

        One pod address is resolved for the master (always task index 0) and
        one for each worker and server replica counted in
        ``self.spec.cluster_def``. The resulting mapping is passed through
        ``MXNetClusterConfig.from_dict(...).to_dict()``.

        Returns:
            The ``to_dict()`` output of an ``MXNetClusterConfig`` built from a
            dict keyed by ``TaskType.MASTER``, ``TaskType.WORKER`` and
            ``TaskType.SERVER`` (inserted in that order), each mapping to a
            list of addresses returned by ``self._get_pod_address``.
        """
        # Only the first element (task type -> replica count) is used here.
        cluster_def, _ = self.spec.cluster_def

        def _addresses_for(task_type, replicas):
            # One pod address per replica, in ascending task-index order.
            return [
                self._get_pod_address(
                    self.pod_manager.get_job_name(task_type=task_type, task_idx=task_idx)
                )
                for task_idx in range(replicas)
            ]

        # The master is always a single pod at task index 0; cluster_def is not
        # consulted for it. Keys are inserted master, worker, server.
        cluster_config = {
            TaskType.MASTER: _addresses_for(TaskType.MASTER, 1),
            TaskType.WORKER: _addresses_for(TaskType.WORKER, cluster_def.get(TaskType.WORKER, 0)),
            TaskType.SERVER: _addresses_for(TaskType.SERVER, cluster_def.get(TaskType.SERVER, 0)),
        }

        return MXNetClusterConfig.from_dict(cluster_config).to_dict()
Example #3
0
    def get_cluster(self):
        """Return the MXNet cluster layout for this job as a plain dict.

        One pod address is resolved for the master (always task index 0) and
        one for each worker and server replica counted in
        ``self.spec.cluster_def``. The resulting mapping is passed through
        ``MXNetClusterConfig.from_dict(...).to_dict()``.

        Returns:
            The ``to_dict()`` output of an ``MXNetClusterConfig`` built from a
            dict keyed by ``TaskType.MASTER``, ``TaskType.WORKER`` and
            ``TaskType.SERVER`` (inserted in that order), each mapping to a
            list of addresses returned by ``self._get_pod_address``.
        """
        # Only the first element (task type -> replica count) is used here.
        cluster_def, _ = self.spec.cluster_def

        def _addresses_for(task_type, replicas):
            # One pod address per replica, in ascending task-index order.
            return [
                self._get_pod_address(
                    self.pod_manager.get_job_name(task_type=task_type,
                                                  task_idx=task_idx))
                for task_idx in range(replicas)
            ]

        # The master is always a single pod at task index 0; cluster_def is
        # not consulted for it. Keys are inserted master, worker, server.
        cluster_config = {
            TaskType.MASTER: _addresses_for(TaskType.MASTER, 1),
            TaskType.WORKER: _addresses_for(
                TaskType.WORKER, cluster_def.get(TaskType.WORKER, 0)),
            TaskType.SERVER: _addresses_for(
                TaskType.SERVER, cluster_def.get(TaskType.SERVER, 0)),
        }

        return MXNetClusterConfig.from_dict(cluster_config).to_dict()