def main() -> None:
    """Receive workloads over a ZMQ broadcast client and verify them.

    Expects exactly one command-line argument: the path to a file from which
    ``layers.WorkerProcessContext.from_file`` can load the worker process
    context. Each received workload is compared against the next item of
    ``fake_workload_gen()`` and answered with ``{"count": <index>}``.

    Exits with status 1 (after printing to stderr) when the argument count is
    wrong. Raises ``AssertionError`` when a workload does not match or when
    the index of the last workload differs from ``NUM_FAKE_WORKLOADS``.
    """
    if len(sys.argv) != 2:
        print(
            "worker_process_env_path must be provided as a commandline argument",
            file=sys.stderr,
        )
        sys.exit(1)

    # Load the worker process env.
    worker_process_env_path = pathlib.Path(sys.argv[1])
    worker_process_env = layers.WorkerProcessContext.from_file(worker_process_env_path)

    # Establish the connection to the ZMQBroadcastServer.
    pub_url = f"tcp://localhost:{worker_process_env.broadcast_pub_port}"
    sub_url = f"tcp://localhost:{worker_process_env.broadcast_pull_port}"
    with ipc.ZMQBroadcastClient(pub_url, sub_url) as broadcast_client:
        # Wrap the communication layer in a workload.Stream.
        subrec = layers.SubprocessReceiver(broadcast_client)

        # Compare the workloads received against the expected stream of workloads.
        expected = fake_workload_gen()

        # Sentinel so that an empty stream fails the final assertion below
        # instead of raising UnboundLocalError on the loop variable.
        last_index = -1
        for last_index, (wkld, _, resp_fn) in enumerate(subrec):
            assert wkld == next(expected), f"unexpected workload at index {last_index}: {wkld}"
            resp_fn({"count": last_index})

        assert (
            last_index == NUM_FAKE_WORKLOADS
        ), f"last workload index was {last_index}, expected {NUM_FAKE_WORKLOADS}"
def _initialize_train_process_comm(self) -> None:
    """Create the ZMQ endpoint used for communication between train processes.

    Requires ``self.hvd_config.use`` to be truthy (enforced via ``check.true``).
    The chief stores a ``ZMQBroadcastServer`` in
    ``self.train_process_comm_chief``; every other process stores a
    ``ZMQBroadcastClient`` pointed at the first rendezvous IP address in
    ``self.train_process_comm_worker``.
    """
    check.true(self.hvd_config.use)

    # Both ports are shifted by the same per-trial offset.
    port_offset = self.env.det_trial_unique_port_offset
    srv_pub_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + port_offset
    srv_pull_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + port_offset

    if not self.is_chief:
        # The chief's address is the first entry of the rendezvous info.
        chief_ip_address = self.rendezvous_info.get_ip_addresses()[0]
        logging.debug(
            f"Non-Chief {hvd.rank()} setting up comm to {chief_ip_address} w/ ports "
            f"{srv_pub_port}/{srv_pull_port}."
        )
        chief_base_url = f"tcp://{chief_ip_address}"
        self.train_process_comm_worker = ipc.ZMQBroadcastClient(
            srv_pub_url=f"{chief_base_url}:{srv_pub_port}",
            srv_pull_url=f"{chief_base_url}:{srv_pull_port}",
        )
        return

    logging.debug(f"Chief setting up server with ports {srv_pub_port}/{srv_pull_port}.")
    # The server expects one connection per slot other than the chief's own.
    worker_count = self.env.experiment_config.slots_per_trial() - 1
    self.train_process_comm_chief = ipc.ZMQBroadcastServer(
        num_connections=worker_count,
        pub_port=srv_pub_port,
        pull_port=srv_pull_port,
    )
def main() -> None:
    """Load the worker process context and run a controller built from it.

    Expects exactly one command-line argument: the path handed to
    ``layers.WorkerProcessContext.from_file``. Prints an error to stderr and
    exits with status 1 when the argument count is wrong.
    """
    argv = sys.argv
    if len(argv) != 2:
        print(
            "worker_process_env_path must be provided as a commandline argument",
            file=sys.stderr,
        )
        sys.exit(1)

    # Deserialize the worker process context from the given path.
    context = layers.WorkerProcessContext.from_file(pathlib.Path(argv[1]))

    config_logging(context)

    # With debugging enabled, dump all tracebacks every 30 seconds.
    if context.env.experiment_config.debug_enabled():
        faulthandler.dump_traceback_later(30, repeat=True)

    # Connect to the ZMQBroadcastServer in this container.
    localhost = "tcp://localhost"
    pub_url = f"{localhost}:{context.broadcast_pub_port}"
    sub_url = f"{localhost}:{context.broadcast_pull_port}"
    with ipc.ZMQBroadcastClient(pub_url, sub_url) as client:
        # Expose the communication layer as a stream of workloads.
        receiver = layers.SubprocessReceiver(client)

        load.prepare_controller(
            context.env,
            iter(receiver),
            context.load_path,
            context.rendezvous_info,
            context.hvd_config,
        ).run()
def main(self) -> None:
    """Run the client half of the server-client communication test.

    Sends an ``ipc.ConnectedMessage`` with ``process_id=0``, then for every
    entry of ``self._exp_msgs`` receives one message, asserts that it equals
    the expected entry, and sends back ``2 * msg``.
    """
    with ipc.ZMQBroadcastClient(self._pub_url, self._pull_url) as client:
        # Announce this client to start the exchange.
        client.send(ipc.ConnectedMessage(process_id=0))

        for expected in self._exp_msgs:
            received = client.recv()
            assert received == expected
            client.send(2 * received)
def main(self) -> None:
    """Run the client half of the server-client communication test.

    Calls ``safe_start()`` on the client, then for every entry of
    ``self._exp_msgs`` receives one message, asserts that it equals the
    expected entry, and sends back ``2 * msg``.
    """
    with ipc.ZMQBroadcastClient(self._pub_url, self._pull_url) as client:
        # Start the exchange with the server.
        client.safe_start()

        for expected in self._exp_msgs:
            received = client.recv()
            assert received == expected
            client.send(2 * received)
def main() -> None:
    """Load the worker process context and run a controller built from it.

    Expects exactly one command-line argument: the path handed to
    ``layers.WorkerProcessContext.from_file``. Prints an error to stderr and
    exits with status 1 when the argument count is wrong.

    Side effects visible here: sets the ``DET_TASK_TOKEN`` environment
    variable, assigns ``certs.cli_cert``, and (when debugging is enabled)
    schedules periodic traceback dumps.

    Any ``Exception`` raised by ``controller.run()`` is reported through
    ``broadcast_client.send_exception_message()`` and then re-raised.
    """
    if len(sys.argv) != 2:
        print("worker_process_env_path must be provided as a commandline argument", file=sys.stderr)
        sys.exit(1)

    # Load the worker process env.
    worker_process_env_path = pathlib.Path(sys.argv[1])
    worker_process_env = layers.WorkerProcessContext.from_file(worker_process_env_path)

    config_logging(worker_process_env)

    # API code expects credential to be available as an environment variable
    os.environ["DET_TASK_TOKEN"] = worker_process_env.env.det_task_token

    # TODO: refactor websocket, data_layer, and profiling to not use the cli_cert.
    scheme = "https" if worker_process_env.env.use_tls else "http"
    master_url = (
        f"{scheme}://{worker_process_env.env.master_addr}:{worker_process_env.env.master_port}"
    )
    certs.cli_cert = certs.default_load(master_url=master_url)

    # With debugging enabled, dump all tracebacks every 30 seconds.
    if worker_process_env.env.experiment_config.debug_enabled():
        faulthandler.dump_traceback_later(30, repeat=True)

    # Establish the connection to the ZMQBroadcastServer in this container.
    pub_url = f"tcp://localhost:{worker_process_env.broadcast_pub_port}"
    sub_url = f"tcp://localhost:{worker_process_env.broadcast_pull_port}"
    with ipc.ZMQBroadcastClient(pub_url, sub_url) as broadcast_client:
        # Wrap the communication layer in a workload.Stream.
        subrec = layers.SubprocessReceiver(broadcast_client)
        workloads = iter(subrec)

        with det._catch_sys_exit():
            with det._catch_init_invalid_hp(workloads):
                controller = load.prepare_controller(
                    worker_process_env.env,
                    workloads,
                    worker_process_env.load_path,
                    worker_process_env.rendezvous_info,
                    worker_process_env.hvd_config,
                )

            try:
                controller.run()
            except Exception:
                # Report the failure over the broadcast channel, then re-raise.
                # A bare `raise` keeps the original traceback intact.
                broadcast_client.send_exception_message()
                raise
def __init__(
    self,
    env: det.EnvContext,
    hvd_config: horovod.HorovodContext,
    rendezvous_info: det.RendezvousInfo,
) -> None:
    """Store the contexts and, when horovod is in use, set up ZMQ comms.

    With horovod in use, the process whose horovod rank is 0 is the chief and
    creates a ``ZMQBroadcastServer`` (``self._chief_zmq``); every other
    process creates a ``ZMQBroadcastClient`` (``self._worker_zmq``) pointed at
    the first rendezvous IP address. Without horovod the process is always
    the chief and no ZMQ objects are created.
    """
    self._env = env
    self._hvd_config = hvd_config
    self._rendezvous_info = rendezvous_info

    # Only consult the horovod rank when horovod is actually in use.
    self._is_chief = (horovod.hvd.rank() == 0) if self._hvd_config.use else True

    if not self._hvd_config.use:
        return

    # Initialize zmq comms; both ports share the same per-trial offset.
    port_offset = self._env.det_trial_unique_port_offset
    srv_pub_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + port_offset
    srv_pull_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + port_offset

    if not self._is_chief:
        chief_ip_address = self._rendezvous_info.get_ip_addresses()[0]
        logging.debug(
            f"Non-Chief {horovod.hvd.rank()} setting up comm to {chief_ip_address} w/ ports "
            f"{srv_pub_port}/{srv_pull_port}."
        )
        chief_base_url = f"tcp://{chief_ip_address}"
        self._worker_zmq = ipc.ZMQBroadcastClient(
            srv_pub_url=f"{chief_base_url}:{srv_pub_port}",
            srv_pull_url=f"{chief_base_url}:{srv_pull_port}",
        )
        return

    logging.debug(f"Chief setting up server with ports {srv_pub_port}/{srv_pull_port}.")
    # The server expects one connection per slot other than the chief's own.
    worker_count = self._env.experiment_config.slots_per_trial() - 1
    self._chief_zmq = ipc.ZMQBroadcastServer(
        num_connections=worker_count,
        pub_port=srv_pub_port,
        pull_port=srv_pull_port,
    )
def _init_ipc(self, force_tcp: bool) -> None:
    """Set up the global and per-machine ZMQ broadcast channels.

    Global channel (skipped when ``self.size < 2``): the chief starts a
    ``ZMQBroadcastServer`` on ``self._pub_port``/``self._pull_port``; every
    other process starts a ``ZMQBroadcastClient`` connected to
    ``self._chief_ip``.

    Local channel (skipped when ``self.local_size < 2``): each local chief
    starts a server and publishes its URLs through ``self.allgather``; each
    local worker picks the entry whose cross rank matches its own and
    connects a client to it.

    Args:
        force_tcp: When true, the local server does not use unix-socket
            (``ipc://``) URLs even if ``socket.AF_UNIX`` is available.

    Raises:
        AssertionError: If a local worker does not find exactly one local
            chief for its machine, or if the gathered URLs are not strings.
    """
    if self.size < 2:
        # No broadcasting necessary.
        return

    # Global broadcast server.
    if self._is_chief:
        logging.debug(f"Chief setting up server with ports {self._pub_port}/{self._pull_port}.")
        self._chief_zmq = ipc.ZMQBroadcastServer(
            num_connections=self.size - 1,
            pub_url=f"tcp://*:{self._pub_port}",
            pull_url=f"tcp://*:{self._pull_port}",
        )
        self._chief_zmq.safe_start(lambda: None)
    else:
        logging.debug(
            f"Non-Chief {self.rank} setting up comm to "
            f"{self._chief_ip} w/ ports "
            f"{self._pub_port}/{self._pull_port}."
        )
        self._worker_zmq = ipc.ZMQBroadcastClient(
            srv_pub_url=f"tcp://{self._chief_ip}:{self._pub_port}",
            srv_pull_url=f"tcp://{self._chief_ip}:{self._pull_port}",
        )
        self._worker_zmq.safe_start()

    if self.local_size < 2:
        # No local broadcasting necessary.
        return

    # Local broadcast server.
    # NOTE(review): tempdir is only assigned when this point is reached; the
    # early returns above leave it untouched. Confirm it is initialized
    # elsewhere before anything reads it.
    self.tempdir = None
    if self._is_local_chief:
        pub_url = None
        pull_url = None

        if hasattr(socket, "AF_UNIX") and not force_tcp:
            # On systems with unix sockets, we get a slight performance bump by using them.
            self.tempdir = tempfile.mkdtemp(prefix="ipc")
            pub_url = f"ipc://{self.tempdir}/pub.sock"
            pull_url = f"ipc://{self.tempdir}/pull.sock"

        logging.debug(f"Local Chief setting up server with urls {pub_url}/{pull_url}.")
        self._local_chief_zmq = ipc.ZMQBroadcastServer(
            num_connections=self.local_size - 1,
            pub_url=pub_url,
            pull_url=pull_url,
        )

        # Without explicit URLs, ask the server which ports it ended up with.
        if pub_url is None:
            pub_url = f"tcp://localhost:{self._local_chief_zmq.get_pub_port()}"
        if pull_url is None:
            pull_url = f"tcp://localhost:{self._local_chief_zmq.get_pull_port()}"

        # Do a global allgather to initialize local clients on every node.
        local_chief = (self.cross_rank, pub_url, pull_url)
        _ = self.allgather(local_chief)
        self._local_chief_zmq.safe_start(lambda: None)
    else:
        # Start with the global allgather.
        all_local_chiefs = self.allgather(None)

        # Local workers contribute None; keep only the chief on this machine.
        my_local_chief = [
            x for x in all_local_chiefs if x is not None and x[0] == self.cross_rank
        ]
        assert len(my_local_chief) == 1, (
            f"did not find exactly 1 local_chief for machine {self.cross_rank} "
            f"in {all_local_chiefs}"
        )

        _, pub_url, pull_url = my_local_chief[0]
        assert isinstance(pub_url, str), f"invalid pub_url: {pub_url}"
        # Fixed: this message previously said "pub_url" for the pull URL.
        assert isinstance(pull_url, str), f"invalid pull_url: {pull_url}"

        logging.debug(f"Local Worker setting up server with urls {pub_url}/{pull_url}.")
        self._local_worker_zmq = ipc.ZMQBroadcastClient(pub_url, pull_url)
        self._local_worker_zmq.safe_start()