def _initialize_train_process_comm(self) -> None:
    """Create this process's end of the inter-train-process ZMQ channel.

    ``check.true`` is applied to ``self.hvd_config.use`` before anything
    else. The chief then builds a ``ZMQBroadcastServer`` sized for every
    slot in the trial except itself; any other process builds a
    ``ZMQBroadcastClient`` pointed at the first rendezvous IP address.
    """
    check.true(self.hvd_config.use)

    # Both ports are the base constants shifted by the per-trial offset.
    port_offset = self.env.det_trial_unique_port_offset
    srv_pub_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + port_offset
    srv_pull_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + port_offset

    if not self.is_chief:
        # Workers connect to the first address in the rendezvous list.
        chief_ip_address = self.rendezvous_info.get_ip_addresses()[0]
        logging.debug(
            f"Non-Chief {hvd.rank()} setting up comm to "
            f"{chief_ip_address} w/ ports "
            f"{srv_pub_port}/{srv_pull_port}."
        )
        pub_endpoint = f"tcp://{chief_ip_address}:{srv_pub_port}"
        pull_endpoint = f"tcp://{chief_ip_address}:{srv_pull_port}"
        self.train_process_comm_worker = ipc.ZMQBroadcastClient(
            srv_pub_url=pub_endpoint,
            srv_pull_url=pull_endpoint,
        )
        return

    logging.debug(f"Chief setting up server with ports {srv_pub_port}/{srv_pull_port}.")
    # The chief does not connect to itself, hence one fewer connection than slots.
    expected_connections = self.env.experiment_config.slots_per_trial() - 1
    self.train_process_comm_chief = ipc.ZMQBroadcastServer(
        num_connections=expected_connections,
        pub_port=srv_pub_port,
        pull_port=srv_pull_port,
    )
def test_broadcast_server_client() -> None:
    """Round-trip messages between a broadcast server and client subprocesses.

    The server first gathers one ``ipc.ConnectedMessage`` per client, then
    broadcasts each integer in ``msgs`` and expects every client to answer
    with twice that value. While polling, ``health_check`` verifies that no
    client subprocess has exited.
    """
    num_subprocs = 3

    with ipc.ZMQBroadcastServer(num_connections=num_subprocs) as broadcast_server:
        pub_url = f"tcp://localhost:{broadcast_server.get_pub_port()}"
        pull_url = f"tcp://localhost:{broadcast_server.get_pull_port()}"
        msgs = list(range(10))

        with SubprocGroup(
            BroadcastClientSubproc(i, num_subprocs, pub_url, pull_url, msgs)
            for i in range(num_subprocs)
        ) as subprocs:

            def health_check() -> None:
                # A single pass is sufficient; report which clients died so a
                # failure is diagnosable instead of a bare AssertionError.
                dead = [idx for idx, subproc in enumerate(subprocs) if not subproc.is_alive()]
                assert not dead, f"client subprocesses exited early: {dead}"

            # Connection handshake: every client announces itself first.
            gathered, _ = broadcast_server.gather_with_polling(health_check)
            assert all(isinstance(g, ipc.ConnectedMessage) for g in gathered)

            for msg in msgs:
                broadcast_server.broadcast(msg)
                gathered, _ = broadcast_server.gather_with_polling(health_check)
                assert all(g == 2 * msg for g in gathered)
def __init__(
    self,
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
    python_subprocess_entrypoint: Optional[str] = None,
) -> None:
    """Store the launch configuration and start this machine's subprocess.

    The machine whose rendezvous rank is 0 acts as chief: when there are
    peer machines it waits on a barrier for their sshd-ready signals, then
    launches either horovodrun or a plain python subprocess depending on
    ``hvd_config.use``. Every other machine launches sshd, waits for it to
    start, and signals the chief through a ZMQ client barrier.

    Raises:
        AssertionError: On the chief, if fewer peers than expected answered
            the sshd barrier.
    """
    self.env = env
    self.workloads = workloads
    self.load_path = load_path
    self.rendezvous_info = rendezvous_info
    self.hvd_config = hvd_config
    self._python_subprocess_entrypoint = python_subprocess_entrypoint

    self.num_gpus = len(self.env.container_gpus)
    self.debug = self.env.experiment_config.debug_enabled()

    # Horovod will have a separate training process for each GPU.
    if self.hvd_config.use:
        number_of_worker_processes = self.num_gpus
    else:
        number_of_worker_processes = 1

    # Step 1: Establish the server for communicating with the subprocess.
    self.broadcast_server = ipc.ZMQBroadcastServer(num_connections=number_of_worker_processes)

    # Step 2: Configure the per-machine WorkerProcessContext.
    self._init_worker_process_env()

    rendezvous = self.rendezvous_info
    self.is_chief_machine = rendezvous.get_rank() == 0
    chief_addr = rendezvous.get_ip_addresses()[0]
    chief_port = rendezvous.get_ports()[0]

    if not self.is_chief_machine:
        # Step 3 (non-chief): launch sshd, wait for it to come up, then signal to the chief.
        self._subproc = self._launch_sshd()
        self._wait_for_sshd_to_start()
        with ipc.ZMQClient(chief_addr, chief_port) as client:
            client.barrier()
        return

    # Step 3 (chief): Wait for any peer machines to launch sshd, then launch horovodrun.
    if rendezvous.get_size() > 1:
        with ipc.ZMQServer(ports=[chief_port], num_connections=1) as server:
            num_peers = rendezvous.get_size() - 1
            responses = server.barrier(num_connections=num_peers, timeout=20)
            if len(responses) < num_peers:
                raise AssertionError(
                    f"Chief received sshd ready signal only from {len(responses)} "
                    f"of {num_peers} machines."
                )
            logging.debug("Chief finished sshd barrier.")

    # Pick the launcher first, then invoke it, so only one launch path runs.
    launch = self._launch_horovodrun if self.hvd_config.use else self._launch_python_subprocess
    self._subproc = launch()
def test_broadcast_server_client() -> None:
    """Broadcast integers to client subprocesses and check the doubled replies.

    After ``safe_start`` on the server, each value in ``msgs`` is broadcast
    and every gathered reply must equal twice that value.
    """
    num_subprocs = 3

    with ipc.ZMQBroadcastServer(num_connections=num_subprocs) as broadcast_server:
        pub_port = broadcast_server.get_pub_port()
        pub_url = f"tcp://localhost:{pub_port}"
        pull_port = broadcast_server.get_pull_port()
        pull_url = f"tcp://localhost:{pull_port}"
        msgs = [*range(10)]

        # Kept as a generator so the clients are built when SubprocGroup consumes it.
        clients = (
            BroadcastClientSubproc(rank, num_subprocs, pub_url, pull_url, msgs)
            for rank in range(num_subprocs)
        )
        with SubprocGroup(clients):
            broadcast_server.safe_start()

            for msg in msgs:
                broadcast_server.broadcast(msg)
                replies = broadcast_server.gather()
                for reply in replies:
                    assert reply == 2 * msg
def __init__(
    self,
    env: det.EnvContext,
    hvd_config: horovod.HorovodContext,
    rendezvous_info: det.RendezvousInfo,
) -> None:
    """Record the context objects and, under Horovod, open the zmq channel.

    Without Horovod this process is marked as chief and no zmq objects are
    created. With Horovod, the process whose ``hvd.rank()`` is 0 is the chief
    and creates a ``ZMQBroadcastServer``; every other rank creates a
    ``ZMQBroadcastClient`` pointed at the first rendezvous IP address.
    """
    self._env = env
    self._hvd_config = hvd_config
    self._rendezvous_info = rendezvous_info

    if not self._hvd_config.use:
        # No Horovod: this process is the chief and needs no comms.
        self._is_chief = True
        return

    self._is_chief = horovod.hvd.rank() == 0

    # Initialize zmq comms.
    port_offset = self._env.det_trial_unique_port_offset
    srv_pub_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + port_offset
    srv_pull_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + port_offset

    if not self._is_chief:
        chief_ip_address = self._rendezvous_info.get_ip_addresses()[0]
        logging.debug(
            f"Non-Chief {horovod.hvd.rank()} setting up comm to "
            f"{chief_ip_address} w/ ports "
            f"{srv_pub_port}/{srv_pull_port}."
        )
        pub_endpoint = f"tcp://{chief_ip_address}:{srv_pub_port}"
        pull_endpoint = f"tcp://{chief_ip_address}:{srv_pull_port}"
        self._worker_zmq = ipc.ZMQBroadcastClient(
            srv_pub_url=pub_endpoint,
            srv_pull_url=pull_endpoint,
        )
        return

    logging.debug(f"Chief setting up server with ports {srv_pub_port}/{srv_pull_port}.")
    # One connection per slot other than the chief's own.
    expected_connections = self._env.experiment_config.slots_per_trial() - 1
    self._chief_zmq = ipc.ZMQBroadcastServer(
        num_connections=expected_connections,
        pub_port=srv_pub_port,
        pull_port=srv_pull_port,
    )
def _init_ipc(self, force_tcp: bool) -> None:
    """Create the ZMQ broadcast channels used by this context.

    Two layers are set up, each skipped when there is nobody to talk to:

    * a global channel when ``self.size`` is at least 2: the chief runs a
      ``ZMQBroadcastServer`` on TCP and every other rank connects a
      ``ZMQBroadcastClient`` to ``self._chief_ip``;
    * a per-machine channel when ``self.local_size`` is at least 2: the
      local chief runs a server and advertises its URLs through
      ``self.allgather``, and local workers pick the entry whose first
      element equals their ``self.cross_rank``.

    Args:
        force_tcp: When true, the local channel does not use ``ipc://``
            socket files even if ``socket`` exposes ``AF_UNIX``.

    Raises:
        AssertionError: On a local worker, if the allgather result does not
            contain exactly one local chief for this machine, or if either
            advertised URL is not a string.
    """
    if self.size < 2:
        # No broadcasting necessary.
        return

    # Global broadcast server.
    if self._is_chief:
        logging.debug(f"Chief setting up server with ports {self._pub_port}/{self._pull_port}.")
        self._chief_zmq = ipc.ZMQBroadcastServer(
            num_connections=self.size - 1,
            pub_url=f"tcp://*:{self._pub_port}",
            pull_url=f"tcp://*:{self._pull_port}",
        )
        self._chief_zmq.safe_start(lambda: None)
    else:
        logging.debug(
            f"Non-Chief {self.rank} setting up comm to "
            f"{self._chief_ip} w/ ports "
            f"{self._pub_port}/{self._pull_port}."
        )
        self._worker_zmq = ipc.ZMQBroadcastClient(
            srv_pub_url=f"tcp://{self._chief_ip}:{self._pub_port}",
            srv_pull_url=f"tcp://{self._chief_ip}:{self._pull_port}",
        )
        self._worker_zmq.safe_start()

    if self.local_size < 2:
        # No local broadcasting necessary.
        return

    # Local broadcast server.
    self.tempdir = None
    if self._is_local_chief:
        pub_url = None
        pull_url = None

        if hasattr(socket, "AF_UNIX") and not force_tcp:
            # On systems with unix sockets, we get a slight performance bump by using them.
            self.tempdir = tempfile.mkdtemp(prefix="ipc")
            pub_url = f"ipc://{self.tempdir}/pub.sock"
            pull_url = f"ipc://{self.tempdir}/pull.sock"

        logging.debug(f"Local Chief setting up server with urls {pub_url}/{pull_url}.")
        self._local_chief_zmq = ipc.ZMQBroadcastServer(
            num_connections=self.local_size - 1,
            pub_url=pub_url,
            pull_url=pull_url,
        )

        # When no URL was supplied, build a TCP one from the ports the server reports.
        if pub_url is None:
            pub_url = f"tcp://localhost:{self._local_chief_zmq.get_pub_port()}"
        if pull_url is None:
            pull_url = f"tcp://localhost:{self._local_chief_zmq.get_pull_port()}"

        # Do a global allgather to initialize local clients on every node.
        local_chief = (self.cross_rank, pub_url, pull_url)
        _ = self.allgather(local_chief)
        self._local_chief_zmq.safe_start(lambda: None)
    else:
        # Start with the global allgather.
        all_local_chiefs = self.allgather(None)
        my_local_chief = [
            x for x in all_local_chiefs if x is not None and x[0] == self.cross_rank
        ]
        assert len(my_local_chief) == 1, (
            f"did not find exactly 1 local_chief for machine {self.cross_rank} "
            f"in {all_local_chiefs}"
        )
        _, pub_url, pull_url = my_local_chief[0]
        assert isinstance(pub_url, str), f"invalid pub_url: {pub_url}"
        # Message names pull_url (it previously said pub_url by mistake).
        assert isinstance(pull_url, str), f"invalid pull_url: {pull_url}"

        logging.debug(f"Local Worker setting up server with urls {pub_url}/{pull_url}.")
        self._local_worker_zmq = ipc.ZMQBroadcastClient(pub_url, pull_url)
        self._local_worker_zmq.safe_start()