示例#1
0
    def _initialize_train_process_comm(self) -> None:
        """Set up the ZMQ broadcast channel shared by the Horovod training processes.

        Horovod must be enabled (checked via ``check.true``). The chief process
        (``self.is_chief``) binds an ``ipc.ZMQBroadcastServer`` that expects one
        connection from every other slot of the trial. Every other process connects an
        ``ipc.ZMQBroadcastClient`` to the chief, whose IP address is the first entry of
        the rendezvous address list. Both ports are the base constants shifted by
        ``env.det_trial_unique_port_offset``.
        """
        check.true(self.hvd_config.use)

        # Shift the base ports by the trial's unique offset (presumably so that
        # concurrent trials on one host do not collide -- confirm).
        port_offset = self.env.det_trial_unique_port_offset
        pub_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + port_offset
        pull_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + port_offset

        if not self.is_chief:
            chief_ip = self.rendezvous_info.get_ip_addresses()[0]
            logging.debug(
                f"Non-Chief {hvd.rank()} setting up comm to "
                f"{chief_ip} w/ ports "
                f"{pub_port}/{pull_port}."
            )
            self.train_process_comm_worker = ipc.ZMQBroadcastClient(
                srv_pub_url=f"tcp://{chief_ip}:{pub_port}",
                srv_pull_url=f"tcp://{chief_ip}:{pull_port}",
            )
            return

        logging.debug(f"Chief setting up server with ports {pub_port}/{pull_port}.")
        # Every slot except the chief's own connects to this server.
        expected_workers = self.env.experiment_config.slots_per_trial() - 1
        self.train_process_comm_chief = ipc.ZMQBroadcastServer(
            num_connections=expected_workers,
            pub_port=pub_port,
            pull_port=pull_port,
        )
示例#2
0
def test_broadcast_server_client() -> None:
    """Round-trip messages between a ZMQBroadcastServer and several client subprocesses.

    The server first gathers one ``ipc.ConnectedMessage`` per client. Then, for
    each value it broadcasts, it expects every client to reply with twice that
    value.
    """
    num_subprocs = 3

    with ipc.ZMQBroadcastServer(num_connections=num_subprocs) as broadcast_server:

        pub_url = f"tcp://localhost:{broadcast_server.get_pub_port()}"
        pull_url = f"tcp://localhost:{broadcast_server.get_pull_port()}"
        msgs = list(range(10))

        with SubprocGroup(
            BroadcastClientSubproc(i, num_subprocs, pub_url, pull_url, msgs)
            for i in range(num_subprocs)
        ) as subprocs:

            def health_check() -> None:
                # Raise if any client process has exited; gather_with_polling
                # presumably calls this while waiting, so a dead client fails the
                # test instead of hanging it.
                assert all(subproc.is_alive() for subproc in subprocs)

            gathered, _ = broadcast_server.gather_with_polling(health_check)
            # all() is vacuously true on an empty list, so pin the count as well.
            assert len(gathered) == num_subprocs
            assert all(isinstance(g, ipc.ConnectedMessage) for g in gathered)

            for msg in msgs:
                broadcast_server.broadcast(msg)
                gathered, _ = broadcast_server.gather_with_polling(health_check)
                assert len(gathered) == num_subprocs
                assert all(g == 2 * msg for g in gathered)
示例#3
0
    def __init__(
        self,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
        python_subprocess_entrypoint: Optional[str] = None,
    ) -> None:
        """Set up IPC for this machine and launch its training/sshd subprocess.

        The chief machine (rendezvous rank 0) waits for every peer machine to report
        that sshd is up. It then launches either ``horovodrun`` or a plain Python
        subprocess. A non-chief machine launches sshd, waits for it to start and then
        signals the chief.

        Args:
            env: Environment context. ``container_gpus`` sizes the Horovod worker
                pool and ``experiment_config.debug_enabled()`` sets ``self.debug``.
            workloads: Workload stream, stored on ``self.workloads``.
            load_path: Optional path, stored on ``self.load_path`` (presumably a
                checkpoint to restore from; confirm with callers).
            rendezvous_info: Rank, size, IP addresses and ports of the trial's
                machines. Rank 0 is treated as the chief machine.
            hvd_config: Horovod settings. ``hvd_config.use`` selects ``horovodrun``
                instead of a single Python subprocess.
            python_subprocess_entrypoint: Stored on
                ``self._python_subprocess_entrypoint``. It is not read in this method
                (presumably consumed by the launch helpers).

        Raises:
            AssertionError: On the chief machine, if the sshd barrier
                (``timeout=20``) returns fewer responses than there are peer
                machines.
        """

        self.env = env
        self.workloads = workloads
        self.load_path = load_path
        self.rendezvous_info = rendezvous_info
        self.hvd_config = hvd_config
        self._python_subprocess_entrypoint = python_subprocess_entrypoint

        self.num_gpus = len(self.env.container_gpus)
        self.debug = self.env.experiment_config.debug_enabled()

        # Horovod will have a separate training process for each GPU.
        number_of_worker_processes = self.num_gpus if self.hvd_config.use else 1

        # Step 1: Establish the server for communicating with the subprocess.
        self.broadcast_server = ipc.ZMQBroadcastServer(num_connections=number_of_worker_processes)

        # Step 2: Configure the per-machine WorkerProcessContext.
        self._init_worker_process_env()

        # Rank 0 in the rendezvous is the chief machine. Its address and port are the
        # first entries of the rendezvous lists. Non-chief machines use them to send
        # the sshd-ready signal.
        self.is_chief_machine = self.rendezvous_info.get_rank() == 0
        chief_addr = self.rendezvous_info.get_ip_addresses()[0]
        chief_port = self.rendezvous_info.get_ports()[0]

        if self.is_chief_machine:
            # Step 3 (chief): Wait for any peer machines to launch sshd, then launch horovodrun.
            if self.rendezvous_info.get_size() > 1:
                # NOTE(review): the server is created with num_connections=1 but the
                # barrier below waits on num_peers connections. Confirm that barrier's
                # num_connections argument overrides the constructor value.
                with ipc.ZMQServer(ports=[chief_port], num_connections=1) as server:
                    num_peers = self.rendezvous_info.get_size() - 1
                    responses = server.barrier(num_connections=num_peers, timeout=20)
                    # A short response list means some peers never signalled readiness.
                    if len(responses) < num_peers:
                        raise AssertionError(
                            f"Chief received sshd ready signal only from {len(responses)} "
                            f"of {num_peers} machines."
                        )
                    logging.debug("Chief finished sshd barrier.")

            if self.hvd_config.use:
                self._subproc = self._launch_horovodrun()
            else:
                self._subproc = self._launch_python_subprocess()

        else:
            # Step 3 (non-chief): launch sshd, wait for it to come up, then signal to the chief.
            self._subproc = self._launch_sshd()

            self._wait_for_sshd_to_start()

            # Join the chief's barrier to report that this machine's sshd is ready.
            with ipc.ZMQClient(chief_addr, chief_port) as client:
                client.barrier()
示例#4
0
def test_broadcast_server_client() -> None:
    """Round-trip messages between a ZMQBroadcastServer and several client subprocesses.

    For each value the server broadcasts, every client is expected to reply with
    twice that value.
    """
    num_subprocs = 3

    with ipc.ZMQBroadcastServer(num_connections=num_subprocs) as broadcast_server:

        pub_url = f"tcp://localhost:{broadcast_server.get_pub_port()}"
        pull_url = f"tcp://localhost:{broadcast_server.get_pull_port()}"
        msgs = list(range(10))

        with SubprocGroup(
            BroadcastClientSubproc(i, num_subprocs, pub_url, pull_url, msgs)
            for i in range(num_subprocs)
        ):
            # Presumably blocks until all clients have connected; confirm in ipc.
            broadcast_server.safe_start()
            for msg in msgs:
                broadcast_server.broadcast(msg)
                gathered = broadcast_server.gather()
                # all() is vacuously true on an empty list, so pin the count as well.
                assert len(gathered) == num_subprocs
                assert all(g == 2 * msg for g in gathered)
示例#5
0
    def __init__(
        self,
        env: det.EnvContext,
        hvd_config: horovod.HorovodContext,
        rendezvous_info: det.RendezvousInfo,
    ) -> None:
        """Record the trial context and, under Horovod, wire up inter-process ZMQ comms.

        Without Horovod this process is the chief and no sockets are created.

        With Horovod, Horovod rank 0 becomes the chief. It binds an
        ``ipc.ZMQBroadcastServer`` (``self._chief_zmq``) that expects
        ``slots_per_trial() - 1`` connections. Every other rank connects an
        ``ipc.ZMQBroadcastClient`` (``self._worker_zmq``) to the first rendezvous IP
        address. Both ports are the base constants shifted by
        ``env.det_trial_unique_port_offset``.
        """
        self._env = env
        self._hvd_config = hvd_config
        self._rendezvous_info = rendezvous_info

        if not self._hvd_config.use:
            # Single training process: trivially the chief, no comms needed.
            self._is_chief = True
            return

        self._is_chief = horovod.hvd.rank() == 0

        # Initialize zmq comms on ports shifted by the trial's unique offset.
        offset = self._env.det_trial_unique_port_offset
        pub_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + offset
        pull_port = constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + offset

        if self._is_chief:
            logging.debug(f"Chief setting up server with ports {pub_port}/{pull_port}.")
            self._chief_zmq = ipc.ZMQBroadcastServer(
                num_connections=self._env.experiment_config.slots_per_trial() - 1,
                pub_port=pub_port,
                pull_port=pull_port,
            )
            return

        chief_ip = self._rendezvous_info.get_ip_addresses()[0]
        logging.debug(
            f"Non-Chief {horovod.hvd.rank()} setting up comm to "
            f"{chief_ip} w/ ports "
            f"{pub_port}/{pull_port}."
        )
        self._worker_zmq = ipc.ZMQBroadcastClient(
            srv_pub_url=f"tcp://{chief_ip}:{pub_port}",
            srv_pull_url=f"tcp://{chief_ip}:{pull_port}",
        )

    def _init_ipc(self, force_tcp: bool) -> None:
        """Create the global and per-machine (local) ZMQ broadcast channels.

        Global channel (skipped when ``self.size < 2``):
            The chief binds an ``ipc.ZMQBroadcastServer`` on ``self._pub_port`` /
            ``self._pull_port`` that expects ``self.size - 1`` connections. Every
            other rank connects an ``ipc.ZMQBroadcastClient`` to ``self._chief_ip``.

        Local channel (skipped when ``self.local_size < 2``):
            The local chief binds a server for the other ``self.local_size - 1``
            ranks on its machine.
            - If ``socket.AF_UNIX`` exists and ``force_tcp`` is false, the server
              uses unix sockets in a fresh temp directory stored in
              ``self.tempdir``.
            - Otherwise it uses localhost TCP ports chosen by the server.
            The server URLs are shared through a global ``allgather``. Each local
            worker then picks the entry whose first element matches its own
            ``cross_rank``.

        Args:
            force_tcp: Use TCP for the local channel even when unix sockets exist.

        Raises:
            AssertionError: If a local worker does not find exactly one local chief
                for its machine, or if the advertised URLs are not strings.
        """
        # Set this up front so the attribute exists on every path, including the
        # early returns below. Only the unix-socket path replaces it with a real
        # directory.
        self.tempdir = None

        if self.size < 2:
            # No broadcasting necessary.
            return

        # Global broadcast server.
        if self._is_chief:
            logging.debug(f"Chief setting up server with ports {self._pub_port}/{self._pull_port}.")
            self._chief_zmq = ipc.ZMQBroadcastServer(
                num_connections=self.size - 1,
                pub_url=f"tcp://*:{self._pub_port}",
                pull_url=f"tcp://*:{self._pull_port}",
            )
            self._chief_zmq.safe_start(lambda: None)

        else:
            logging.debug(
                f"Non-Chief {self.rank} setting up comm to "
                f"{self._chief_ip} w/ ports "
                f"{self._pub_port}/{self._pull_port}."
            )
            self._worker_zmq = ipc.ZMQBroadcastClient(
                srv_pub_url=f"tcp://{self._chief_ip}:{self._pub_port}",
                srv_pull_url=f"tcp://{self._chief_ip}:{self._pull_port}",
            )
            self._worker_zmq.safe_start()

        if self.local_size < 2:
            # No local broadcasting necessary.
            return

        # Local broadcast server.
        if self._is_local_chief:
            pub_url = None
            pull_url = None
            if hasattr(socket, "AF_UNIX") and not force_tcp:
                # On systems with unix sockets, we get a slight performance bump by using them.
                self.tempdir = tempfile.mkdtemp(prefix="ipc")
                pub_url = f"ipc://{self.tempdir}/pub.sock"
                pull_url = f"ipc://{self.tempdir}/pull.sock"

            logging.debug(f"Local Chief setting up server with urls {pub_url}/{pull_url}.")
            self._local_chief_zmq = ipc.ZMQBroadcastServer(
                num_connections=self.local_size - 1,
                pub_url=pub_url,
                pull_url=pull_url,
            )

            # With no explicit URL the server picked its own TCP ports; advertise them.
            if pub_url is None:
                pub_url = f"tcp://localhost:{self._local_chief_zmq.get_pub_port()}"

            if pull_url is None:
                pull_url = f"tcp://localhost:{self._local_chief_zmq.get_pull_port()}"

            # Do a global allgather to initialize local clients on every node.
            local_chief = (self.cross_rank, pub_url, pull_url)
            _ = self.allgather(local_chief)
            self._local_chief_zmq.safe_start(lambda: None)

        else:
            # Start with the global allgather. Non-chiefs contribute None, so only
            # the local chiefs' (cross_rank, pub_url, pull_url) tuples are meaningful.
            all_local_chiefs = self.allgather(None)
            my_local_chief = [
                x for x in all_local_chiefs if x is not None and x[0] == self.cross_rank
            ]
            assert len(my_local_chief) == 1, (
                f"did not find exactly 1 local_chief for machine {self.cross_rank} "
                f"in {all_local_chiefs}"
            )
            _, pub_url, pull_url = my_local_chief[0]

            assert isinstance(pub_url, str), f"invalid pub_url: {pub_url}"
            assert isinstance(pull_url, str), f"invalid pull_url: {pull_url}"

            logging.debug(f"Local Worker setting up server with urls {pub_url}/{pull_url}.")
            self._local_worker_zmq = ipc.ZMQBroadcastClient(pub_url, pull_url)
            self._local_worker_zmq.safe_start()