コード例 #1
0
    def __init__(
        self,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
        python_subprocess_entrypoint: Optional[str] = None,
    ) -> None:
        """Record the launch configuration and start this machine's subprocess.

        The machine whose rendezvous rank is 0 acts as the chief. When more
        than one machine participates, the chief first waits on a barrier for
        every peer to report in, and then starts either horovodrun or a plain
        Python subprocess depending on ``hvd_config.use``. Every other machine
        starts sshd, waits for it to come up, and then signals the chief.

        Args:
            env: Environment context; read here for ``container_gpus`` and
                ``experiment_config.debug_enabled()``.
            workloads: Workload stream, stored on the instance.
            load_path: Optional path, stored on the instance.
            rendezvous_info: Supplies this machine's rank, the number of
                machines, and the chief's address and port (index 0 of the
                address and port lists).
            hvd_config: Horovod settings; ``use`` selects the number of worker
                connections and which subprocess the chief launches.
            python_subprocess_entrypoint: Optional entrypoint, stored on the
                instance as ``_python_subprocess_entrypoint``.

        Raises:
            AssertionError: On the chief, if fewer peers than expected
                answered the sshd barrier.
        """
        # Keep every constructor argument on the instance.
        self.env = env
        self.workloads = workloads
        self.load_path = load_path
        self.rendezvous_info = rendezvous_info
        self.hvd_config = hvd_config
        self._python_subprocess_entrypoint = python_subprocess_entrypoint

        self.num_gpus = len(self.env.container_gpus)
        self.debug = self.env.experiment_config.debug_enabled()

        # Horovod runs one training process per GPU; without it there is one.
        if self.hvd_config.use:
            worker_count = self.num_gpus
        else:
            worker_count = 1

        # Step 1: the server used to talk to the subprocess(es).
        self.broadcast_server = ipc.ZMQBroadcastServer(num_connections=worker_count)

        # Step 2: per-machine worker process environment.
        self._init_worker_process_env()

        self.is_chief_machine = self.rendezvous_info.get_rank() == 0
        chief_host = self.rendezvous_info.get_ip_addresses()[0]
        rendezvous_port = self.rendezvous_info.get_ports()[0]

        if not self.is_chief_machine:
            # Step 3 (non-chief): bring up sshd, then tell the chief it is ready.
            self._subproc = self._launch_sshd()
            self._wait_for_sshd_to_start()
            with ipc.ZMQClient(chief_host, rendezvous_port) as client:
                client.barrier()
            return

        # Step 3 (chief): with peers present, wait for each one's sshd-ready
        # signal before launching anything.
        if self.rendezvous_info.get_size() > 1:
            with ipc.ZMQServer(ports=[rendezvous_port], num_connections=1) as server:
                num_peers = self.rendezvous_info.get_size() - 1
                responses = server.barrier(num_connections=num_peers, timeout=20)
                if num_peers > len(responses):
                    raise AssertionError(
                        f"Chief received sshd ready signal only from {len(responses)} "
                        f"of {num_peers} machines."
                    )
                logging.debug("Chief finished sshd barrier.")

        launch = (
            self._launch_horovodrun
            if self.hvd_config.use
            else self._launch_python_subprocess
        )
        self._subproc = launch()
コード例 #2
0
 def _initialize_train_process_comm(self) -> None:
     """Create this training process's end of the inter-process channel.

     The chief binds a ``ZMQServer`` on
     ``constants.INTER_TRAIN_PROCESS_COMM_PORT`` and stores it as
     ``train_process_comm_chief``. Every other process stores a ``ZMQClient``
     aimed at the first rendezvous IP address on that same port as
     ``train_process_comm_worker``.

     ``self.hvd_config.use`` is passed to ``check.true`` first (presumably
     this requires Horovod to be enabled -- confirm against ``check``).
     """
     check.true(self.hvd_config.use)

     if self.is_chief:
         chief_msg = (
             f"Chief {hvd.rank()} setting up server with "
             f"port {constants.INTER_TRAIN_PROCESS_COMM_PORT}."
         )
         logging.debug(chief_msg)
         self.train_process_comm_chief = ipc.ZMQServer(
             ports=[constants.INTER_TRAIN_PROCESS_COMM_PORT],
             num_connections=1,
         )
         return

     # Non-chief: connect to the address listed first in the rendezvous info.
     chief_ip_address = self.rendezvous_info.get_ip_addresses()[0]
     worker_msg = (
         f"Non-Chief {hvd.rank()} setting up comm to "
         f"{chief_ip_address} w/ port {constants.INTER_TRAIN_PROCESS_COMM_PORT}."
     )
     logging.debug(worker_msg)
     self.train_process_comm_worker = ipc.ZMQClient(
         ip_address=chief_ip_address,
         port=constants.INTER_TRAIN_PROCESS_COMM_PORT,
     )
コード例 #3
0
def test_zmq_server_client() -> None:
    """Round-trip a dict between a ZMQServer and a ZMQClient in both directions.

    The server picks its own port from ``port_range``; the client connects to
    that port on localhost, sends a dict, and the server must receive an equal
    object. The server then mutates its copy and sends it back, and the client
    must receive an equal object again.

    Both endpoints are opened as context managers so they are released even
    when an assertion fails part-way through the test.
    """
    with ipc.ZMQServer(num_connections=1, ports=None, port_range=(1000, 65535)) as server:
        # Exactly one port is reported, and it lies inside the requested range.
        assert len(server.get_ports()) == 1
        port = server.get_ports()[0]
        assert 1000 <= port <= 65535

        with ipc.ZMQClient(ip_address="localhost", port=port) as client:
            # Client -> server. The non-string key 12345 is part of the payload.
            client_object = {"DeterminedAI": "Great", "det": "Fantastic", 12345: -100}
            client.send(client_object)
            server_object = server.receive_blocking(send_rank=0)
            assert server_object == client_object

            # Server -> client, with a changed value so a stale echo would fail.
            server_object["DeterminedAI"] = "VeryGreat"
            server.send(server_object)
            client_object = client.receive()
            assert server_object == client_object