Example #1
    def create_rendezvous_server_if_needed(self, args):
        print("strategy: {}".format(args.distribution_strategy))
        if args.distribution_strategy != DistributionStrategy.ALLREDUCE:
            self.rendezvous_server = None
        else:
            master_ip = os.getenv("MY_POD_IP", "localhost")
            self.rendezvous_server = HorovodRendezvousServer(master_ip)
Example #2
    def test_get_comm_rank(self):
        self.master.rendezvous_server = HorovodRendezvousServer(
            server_host="localhost")
        self.master.rendezvous_server.start()
        self.master.rendezvous_server.set_worker_hosts(
            ["172.0.0.1", "172.0.0.2"])

        k8s_client = Mock()
        k8s_client.get_worker_service_address = MagicMock(
            return_value="172.0.0.1:8080")
        self.master.instance_manager = Mock(_k8s_client=k8s_client)
        master_servicer = MasterServicer(3,
                                         evaluation_service=None,
                                         master=self.master)
        request = elasticdl_pb2.GetCommRankRequest()
        request.worker_id = 0
        rank_response = master_servicer.get_comm_rank(request, None)
        self.assertEqual(rank_response.world_size, 2)
        self.assertEqual(rank_response.rank_id, 0)
        self.assertEqual(rank_response.rendezvous_id, 1)
Example #3
    def test_get_comm_rank(self):
        self.master.rendezvous_server = HorovodRendezvousServer(
            server_host="localhost")
        self.master.rendezvous_server.start()
        self.master.rendezvous_server.set_worker_hosts([
            ("worker-0", "172.0.0.1"), ("worker-1", "172.0.0.2")
        ])

        mock_instance_manager = Mock()
        mock_instance_manager.get_worker_pod_ip = MagicMock(
            return_value="172.0.0.1")
        self.master.instance_manager = mock_instance_manager
        master_servicer = MasterServicer(3,
                                         evaluation_service=None,
                                         master=self.master)
        request = elasticdl_pb2.GetCommRankRequest()
        request.worker_id = 0
        rank_response = master_servicer.get_comm_rank(request, None)
        self.assertEqual(rank_response.world_size, 2)
        self.assertEqual(rank_response.rank_id, 0)
        self.assertEqual(rank_response.rendezvous_id, 1)
Example #4
    def test_get_comm_rank(self):
        self.master.rendezvous_server = HorovodRendezvousServer(
            server_host="localhost")
        self.master.rendezvous_server.start()
        self.master.rendezvous_server.add_worker("172.0.0.1")
        self.master.rendezvous_server.add_worker("172.0.0.2")

        mock_instance_manager = Mock()
        mock_instance_manager.get_worker_pod_ip = MagicMock(
            return_value="172.0.0.1")
        self.master.instance_manager = mock_instance_manager
        master_servicer = MasterServicer(
            self.master.task_manager,
            self.master.instance_manager,
            self.master.rendezvous_server,
            None,
        )
        request = elasticai_api_pb2.GetCommRankRequest()
        request.worker_host = "172.0.0.1"
        rank_response = master_servicer.get_comm_rank(request, None)
        self.assertEqual(rank_response.world_size, 2)
        self.assertEqual(rank_response.rank_id, 0)
        self.assertEqual(rank_response.rendezvous_id, 1)
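The three tests above call get_comm_rank on the servicer directly; in a real job each worker issues the same RPC to the master over gRPC. Below is a minimal client-side sketch of that call. It is only a sketch: the stub class name MasterStub is inferred from the add_MasterServicer_to_server call in Example #6, and the import paths and master address are assumptions, not the project's documented API.

import grpc

# Import paths are assumptions; adjust them to your ElasticDL installation.
from elasticdl.proto import elasticdl_pb2, elasticdl_pb2_grpc

channel = grpc.insecure_channel("elasticdl-master:50001")  # illustrative address
stub = elasticdl_pb2_grpc.MasterStub(channel)

request = elasticdl_pb2.GetCommRankRequest()
request.worker_id = 0
response = stub.get_comm_rank(request)

# The response carries what a worker needs to join the Horovod ring:
# its rank, the world size, and the current rendezvous round.
print(response.world_size, response.rank_id, response.rendezvous_id)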
Example #5
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir
        self.distribution_strategy = args.distribution_strategy

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)
        self.rendezvous_server = None
        if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
            self.rendezvous_server = HorovodRendezvousServer(master_ip)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Initialize the callbacks
        self.callbacks_list = load_callbacks_from_module(
            args.callbacks, self.model_module
        )
        self.callbacks_list.set_model(self.model_inst)
        set_callback_parameters(
            self.callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )
        self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
            self.callbacks_list,
        )

        self.task_d.add_deferred_callback_create_train_end_task()
        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        self._should_stop = False
        self._exit_code = 0
        threading.Thread(
            target=self._check_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        ).start()
Example #6
class Master(object):
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir
        self.distribution_strategy = args.distribution_strategy

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)
        self.rendezvous_server = None
        if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
            self.rendezvous_server = HorovodRendezvousServer(master_ip)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Initialize the callbacks
        self.callbacks_list = load_callbacks_from_module(
            args.callbacks, self.model_module
        )
        self.callbacks_list.set_model(self.model_inst)
        set_callback_parameters(
            self.callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )
        self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
            self.callbacks_list,
        )

        self.task_d.add_deferred_callback_create_train_end_task()
        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        self._should_stop = False
        self._exit_code = 0
        threading.Thread(
            target=self._check_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        ).start()

    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        if not checkpoint_dir_for_init:
            return

        if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
            raise ValueError(
                "Invalid checkpoint directory {}".format(
                    checkpoint_dir_for_init
                )
            )

        model_version = CheckpointSaver.get_version_from_checkpoint(
            checkpoint_dir_for_init
        )
        for callback in self.callbacks_list.callbacks:
            if isinstance(callback, MaxStepsStopping):
                callback.set_completed_steps(model_version)

    def request_stop(self, err_msg=None):
        """Request master to quit"""
        self._should_stop = True
        if err_msg:
            self.logger.error(err_msg)
            # TODO (chengfu.wcy) create meaningful status codes
            self._exit_code = -1

    def prepare(self):
        """
        Start the components one by one and make sure each is ready to run.
        """
        # Start the evaluation service if requested
        if self.evaluation_service:
            self.logger.info("Starting evaluation service")
            self.evaluation_service.start()
            self.logger.info("Evaluation service started")

        # Start the master GRPC server
        self.logger.info("Starting master RPC server")
        self.server.start()
        self.logger.info("Master RPC server started")

        # Start the worker manager if requested
        if self.instance_manager:
            self.instance_manager.update_status(InstanceManagerStatus.PENDING)
            if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
                # Start rendezvous server for workers to initialize Horovod
                self.rendezvous_server.start()
            else:
                self.instance_manager.start_parameter_servers()
            self.instance_manager.start_workers()
            self.instance_manager.update_status(InstanceManagerStatus.RUNNING)

        # Start TensorBoard k8s Service if requested
        if self.tb_service and self.tb_client:
            self.logger.info("Starting tensorboard service")
            self.tb_service.start()
            self.tb_client.start_tensorboard_service()
            self.logger.info("Tensorboard service started")

    def run(self):
        """
        The main loop of the master.
        Dispatch the tasks to the workers until all the tasks are completed.
        """
        try:
            while True:
                if (
                    self.instance_manager
                    and self.instance_manager.all_workers_failed
                ):
                    raise Exception(
                        "All workers failed with unrecoverable errors"
                    )
                if self.task_d.finished():
                    if self.instance_manager:
                        self.instance_manager.update_status(
                            InstanceManagerStatus.FINISHED
                        )
                    break
                if self._should_stop:
                    break
                time.sleep(30)
        except KeyboardInterrupt:
            self.logger.warning("Server stopping")
        finally:
            self._stop()
        return self._exit_code

    def _stop(self):
        """
        Stop all the components.
        Make sure that the created services and components are shut down.
        """
        self.logger.info("Stopping master")

        if self.evaluation_service:
            self.logger.info("Stopping evaluation service")
            self.evaluation_service.stop()
            self.logger.info("Evaluation service stopped")

        self.logger.info("Stopping RPC server")
        self.server.stop(None)  # grace = None
        self.logger.info("RPC server stopped")

        # Keep TensorBoard running when all the tasks are finished
        if self.tb_service:
            self.logger.info(
                "All tasks finished. Keeping TensorBoard service running..."
            )
            while True:
                if self.tb_service.is_active():
                    time.sleep(10)
                else:
                    self.logger.warning(
                        "Unable to keep TensorBoard running. "
                        "It has already terminated"
                    )
                    break
        self.logger.info("Master stopped")

    @staticmethod
    def _get_job_type(args):
        if all(
            (
                args.training_data,
                args.validation_data,
                args.evaluation_throttle_secs or args.evaluation_steps,
            )
        ):
            job_type = JobType.TRAINING_WITH_EVALUATION
        elif all(
            (
                args.validation_data,
                not args.training_data,
                not args.prediction_data,
            )
        ):
            job_type = JobType.EVALUATION_ONLY
        elif all(
            (
                args.prediction_data,
                not args.validation_data,
                not args.training_data,
            )
        ):
            job_type = JobType.PREDICTION_ONLY
        else:
            job_type = JobType.TRAINING_ONLY

        return job_type

    def _create_tensorboard_service(self, tensorboard_log_dir, master_ip):
        tb_service = None
        if tensorboard_log_dir:
            self.logger.info(
                "Creating TensorBoard service with log directory %s",
                tensorboard_log_dir,
            )
            # Start TensorBoard CLI
            tb_service = TensorboardService(tensorboard_log_dir, master_ip)

        return tb_service

    def _create_evaluation_service(self, args):
        evaluation_service = None
        if (
            self.job_type == JobType.TRAINING_WITH_EVALUATION
            or self.job_type == JobType.EVALUATION_ONLY
        ):
            self.logger.info(
                "Creating evaluation service with throttle seconds %d "
                " and evaluation steps %d",
                args.evaluation_throttle_secs,
                args.evaluation_steps,
            )
            evaluation_service = EvaluationService(
                self.tb_service,
                self.task_d,
                args.evaluation_start_delay_secs,
                args.evaluation_throttle_secs,
                args.evaluation_steps,
                self.job_type == JobType.EVALUATION_ONLY,
                self.model_module[args.eval_metrics_fn],
            )
            self.task_d.set_evaluation_service(evaluation_service)

        return evaluation_service

    def _create_master_service(self, args):
        self.logger.info("Creating master service")
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=64),
            options=[
                ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
                (
                    "grpc.max_receive_message_length",
                    GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                ),
            ],
        )
        master_servicer = MasterServicer(
            args.minibatch_size,
            evaluation_service=self.evaluation_service,
            master=self,
        )
        elasticdl_pb2_grpc.add_MasterServicer_to_server(
            master_servicer, server
        )
        server.add_insecure_port("[::]:{}".format(args.port))
        self.logger.info("The port of the master server is: %d", args.port)

        return master_servicer, server

    def _create_instance_manager(self, args):
        instance_manager = None

        container_command = ["/bin/bash"]
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_client_command = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.worker.main"
            )
            worker_args = [
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            worker_args.extend(
                build_arguments_from_parsed_result(args, filter_args=["envs"])
            )
            worker_args = wrap_python_args_with_string(worker_args)
            worker_args.insert(0, worker_client_command)

            if args.use_go_ps:
                opt_type, opt_args = get_optimizer_info(self.optimizer)
                ps_command = "elasticdl_ps"
                ps_command_args = [
                    "-job_name=" + args.job_name,
                    "-namespace=" + args.namespace,
                    "-master_addr=" + self.master_addr,
                    "-port=2222",
                    "-use_async=" + ("true" if args.use_async else "false"),
                    "-grads_to_wait=" + str(args.grads_to_wait),
                    "-lr_staleness_modulation="
                    + ("true" if args.lr_staleness_modulation else "false"),
                    "-sync_version_tolerance="
                    + str(args.sync_version_tolerance),
                    "-evaluation_steps=" + str(args.evaluation_steps),
                    "-num_ps_pods=" + str(args.num_ps_pods),
                    "-num_workers=" + str(args.num_workers),
                    "-checkpoint_dir=" + str(args.checkpoint_dir),
                    "-checkpoint_steps=" + str(args.checkpoint_steps),
                    "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                    "-checkpoint_dir_for_init="
                    + str(args.checkpoint_dir_for_init),
                    "-opt_type=" + opt_type,
                    "-opt_args=" + opt_args,
                ]
                ps_command_args = wrap_go_args_with_string(ps_command_args)
                # Execute `source /root/.bashrc_elasticdl` to add the file
                # path of `elasticdl_ps` to the PATH environment variable.
                ps_args = [
                    "source",
                    "/root/.bashrc_elasticdl",
                    "&&",
                    ps_command,
                ]
                ps_args.extend(ps_command_args)
            else:
                ps_command = (
                    BashCommandTemplate.SET_PIPEFAIL
                    + " python -m elasticdl.python.ps.main"
                )
                ps_command_args = [
                    "--grads_to_wait",
                    str(args.grads_to_wait),
                    "--lr_staleness_modulation",
                    str(args.lr_staleness_modulation),
                    "--sync_version_tolerance",
                    str(args.sync_version_tolerance),
                    "--use_async",
                    str(args.use_async),
                    "--model_zoo",
                    args.model_zoo,
                    "--model_def",
                    args.model_def,
                    "--job_name",
                    args.job_name,
                    "--port",
                    "2222",
                    "--master_addr",
                    self.master_addr,
                    "--namespace",
                    args.namespace,
                    "--evaluation_steps",
                    str(args.evaluation_steps),
                    "--checkpoint_dir",
                    str(args.checkpoint_dir),
                    "--checkpoint_steps",
                    str(args.checkpoint_steps),
                    "--keep_checkpoint_max",
                    str(args.keep_checkpoint_max),
                    "--num_ps_pods",
                    str(args.num_ps_pods),
                    "--checkpoint_dir_for_init",
                    str(args.checkpoint_dir_for_init),
                    "--num_workers",
                    str(args.num_workers),
                    "--log_level",
                    str(args.log_level),
                    "--minibatch_size",
                    str(args.minibatch_size),
                    "--num_minibatches_per_task",
                    str(args.num_minibatches_per_task),
                ]
                ps_args = wrap_python_args_with_string(ps_command_args)
                ps_args.insert(0, ps_command)

            worker_args = ["-c", " ".join(worker_args)]
            ps_args = ["-c", " ".join(ps_args)]

            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))

            kwargs = get_dict_from_params_str(args.aux_params)
            disable_relaunch = kwargs.get("disable_relaunch", False)
            cluster_spec = self._get_image_cluster_spec(args.cluster_spec)

            instance_manager = InstanceManager(
                self.task_d,
                rendezvous_server=self.rendezvous_server,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=container_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=container_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=cluster_spec,
                envs=env,
                disable_relaunch=disable_relaunch,
                log_file_path=args.log_file_path,
            )

        return instance_manager

    def _get_image_cluster_spec(self, cluster_spec):
        if cluster_spec:
            filename = os.path.basename(cluster_spec)
            image_cluster_spec = os.path.join(
                ClusterSpecConfig.CLUSTER_SPEC_DIR, filename
            )
            return image_cluster_spec
        return cluster_spec

    def _check_timeout_tasks(self):
        while True:
            doing_tasks = self.task_d._doing.copy()
            cur_time = time.time()
            avg_time = self.master_servicer.get_average_task_complete_time()
            for task_id, (worker_id, task, start_time) in doing_tasks.items():
                if task.type == elasticdl_pb2.TRAINING:
                    start_time = self.master_servicer.get_worker_liveness_time(
                        worker_id
                    )
                if task.type in [
                    elasticdl_pb2.TRAINING,
                    elasticdl_pb2.EVALUATION,
                ]:
                    if (cur_time - start_time) > 3 * avg_time[task.type]:
                        self.logger.info(
                            "Worker %d timed out, relaunching it" % worker_id
                        )
                        self.task_d.recover_tasks(worker_id)
                        # TODO: save worker logs before remove it
                        self.instance_manager._remove_worker(worker_id)
                        break
            time.sleep(30)
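Taken together, the Master above is driven by a small entry point: build it from parsed arguments, call prepare() to bring up the evaluation service, RPC server, pods and TensorBoard, then run() to block until the job ends. A minimal driver sketch (parse_master_args is a hypothetical stand-in for the project's argument parser):

import sys

# Hypothetical helper; replace with however your entry point parses arguments.
args = parse_master_args()

master = Master(args)
master.prepare()          # start services, RPC server and worker/PS pods
exit_code = master.run()  # block until all tasks finish or a failure occurs
sys.exit(exit_code)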
Example #7
class HorovodRendezvousServerTest(unittest.TestCase):
    def setUp(self):
        self.rendezvous_server = HorovodRendezvousServer(
            server_host="127.0.0.1"
        )
        self.rendezvous_server.start()

    def test_get_host_plan(self):
        self.rendezvous_server._worker_name_hosts = [
            ("worker-0", "127.0.0.2"),
            ("worker-1", "127.0.0.3"),
        ]
        host_alloc_plan = self.rendezvous_server._get_host_plan()
        self.assertEqual(host_alloc_plan[0].hostname, "127.0.0.2")
        self.assertEqual(host_alloc_plan[0].rank, 0)
        self.assertEqual(host_alloc_plan[0].size, 2)
        self.assertEqual(host_alloc_plan[1].hostname, "127.0.0.3")
        self.assertEqual(host_alloc_plan[1].rank, 1)
        self.assertEqual(host_alloc_plan[1].size, 2)

    def test_set_worker_hosts(self):
        worker_name_hosts = [
            ("worker-0", "127.0.0.2"),
            ("worker-1", "127.0.0.3"),
        ]
        self.rendezvous_server.set_worker_hosts(worker_name_hosts)
        rank_0 = self.rendezvous_server.get_worker_host_rank("127.0.0.2")
        rank_1 = self.rendezvous_server.get_worker_host_rank("127.0.0.3")
        self.assertEqual(rank_0, 0)
        self.assertEqual(rank_1, 1)
        self.assertEqual(self.rendezvous_server._rendezvous_completed, True)
        self.assertEqual(self.rendezvous_server._rendezvous_id, 1)

        new_worker_name_hosts = [
            ("worker-2", "127.0.0.1"),
            ("worker-1", "127.0.0.3"),
        ]
        self.rendezvous_server.set_worker_hosts(new_worker_name_hosts)
        self.rendezvous_server._init_rendezvous_server()
        self.assertEqual(self.rendezvous_server._rendezvous_id, 2)

    def test_get_attr(self):
        worker_name_hosts = [
            ("worker-0", "127.0.0.2"),
            ("worker-1", "127.0.0.3"),
        ]
        self.rendezvous_server.set_worker_hosts(worker_name_hosts)
        self.assertEqual(
            self.rendezvous_server.get_rendezvous_host(), "127.0.0.1"
        )
        self.assertEqual(
            self.rendezvous_server.get_worker_host_rank("127.0.0.2"), 0
        )
        self.assertEqual(self.rendezvous_server.get_size(), 2)
        self.assertEqual(self.rendezvous_server.get_rendezvous_id(), 1)
Example #8
class HorovodRendezvousServerTest(unittest.TestCase):
    def setUp(self):
        self.rendezvous_server = HorovodRendezvousServer(
            server_host="127.0.0.1"
        )
        self.rendezvous_server.start()

    def test_get_host_plan(self):
        self.rendezvous_server._worker_hosts = ["127.0.0.2", "127.0.0.3"]
        host_alloc_plan = self.rendezvous_server._get_host_plan()
        self.assertEqual(host_alloc_plan[0].hostname, "127.0.0.2")
        self.assertEqual(host_alloc_plan[0].rank, 0)
        self.assertEqual(host_alloc_plan[0].size, 2)
        self.assertEqual(host_alloc_plan[1].hostname, "127.0.0.3")
        self.assertEqual(host_alloc_plan[1].rank, 1)
        self.assertEqual(host_alloc_plan[1].size, 2)

    def test_set_worker_hosts(self):
        worker_hosts = ["127.0.0.2", "127.0.0.3"]
        self.rendezvous_server.set_worker_hosts(worker_hosts)
        self.assertEqual(self.rendezvous_server._rendezvous_id, 1)

        new_worker_hosts = ["127.0.0.1", "127.0.0.3"]
        self.rendezvous_server.set_worker_hosts(new_worker_hosts)
        self.assertEqual(self.rendezvous_server._rendezvous_id, 2)

    def test_get_attr(self):
        worker_hosts = ["127.0.0.2", "127.0.0.3"]
        self.rendezvous_server.set_worker_hosts(worker_hosts)
        self.assertEqual(
            self.rendezvous_server.get_rendezvous_host(), "127.0.0.1"
        )
        self.assertEqual(
            self.rendezvous_server.get_worker_host_rank("127.0.0.2"), 0
        )
        self.assertEqual(self.rendezvous_server.get_size(), 2)
        self.assertEqual(self.rendezvous_server.get_rendezvous_id(), 1)
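The same rendezvous flow can be exercised outside of unittest. Below is a minimal sketch of the server lifecycle as the tests above use it; the import path is an assumption, and the expected values in the comments mirror the assertions in the tests.

# Import path assumed; adjust to where HorovodRendezvousServer lives in your tree.
from elasticdl.python.master.rendezvous_server import HorovodRendezvousServer

server = HorovodRendezvousServer(server_host="127.0.0.1")
server.start()

# Registering the current worker hosts starts a rendezvous round.
server.set_worker_hosts(["127.0.0.2", "127.0.0.3"])
print(server.get_rendezvous_host())              # "127.0.0.1"
print(server.get_size())                         # 2
print(server.get_worker_host_rank("127.0.0.2"))  # 0
print(server.get_rendezvous_id())                # 1

# Changing the host set (e.g. after a worker pod is relaunched) bumps the id.
server.set_worker_hosts(["127.0.0.1", "127.0.0.3"])
print(server.get_rendezvous_id())                # 2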