예제 #1
0
def main() -> None:
    """Check workloads received over the broadcast client against the expected fakes.

    Requires exactly one command-line argument, the path of the worker process
    env file; otherwise prints a message to stderr and exits with status 1.
    """
    argv = sys.argv
    if len(argv) != 2:
        usage = "worker_process_env_path must be provided as a commandline argument"
        print(usage, file=sys.stderr)
        sys.exit(1)

    # Rebuild the worker process context from the file named on the command line.
    env_path = pathlib.Path(argv[1])
    context = layers.WorkerProcessContext.from_file(env_path)

    # The ZMQBroadcastServer is reached on localhost using the ports stored in the context.
    pub_url = f"tcp://localhost:{context.broadcast_pub_port}"
    sub_url = f"tcp://localhost:{context.broadcast_pull_port}"

    with ipc.ZMQBroadcastClient(pub_url, sub_url) as client:
        # Wrap the communication layer in a workload.Stream.
        receiver = layers.SubprocessReceiver(client)
        expected_workloads = fake_workload_gen()

        # Every received workload must equal the next expected one; the reply
        # carries the zero-based index of that workload.
        for i, (received, _, respond) in enumerate(iter(receiver)):
            assert received == next(expected_workloads)
            respond({"count": i})

        # After the loop, `i` is the index of the last workload received.
        assert i == NUM_FAKE_WORKLOADS
예제 #2
0
    def _initialize_train_process_comm(self) -> None:
        """Create the ZMQ broadcast endpoint used between training processes.

        The chief constructs a ZMQBroadcastServer; every other process constructs
        a ZMQBroadcastClient aimed at the first rendezvous IP address. Both ports
        are the base constants shifted by the trial's unique port offset.
        """
        check.true(self.hvd_config.use)

        pub_port = (
            constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + self.env.det_trial_unique_port_offset
        )
        pull_port = (
            constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + self.env.det_trial_unique_port_offset
        )

        if not self.is_chief:
            # Workers connect to the first address in the rendezvous list.
            chief_addr = self.rendezvous_info.get_ip_addresses()[0]
            logging.debug(
                f"Non-Chief {hvd.rank()} setting up comm to {chief_addr} w/ ports "
                f"{pub_port}/{pull_port}."
            )
            self.train_process_comm_worker = ipc.ZMQBroadcastClient(
                srv_pub_url=f"tcp://{chief_addr}:{pub_port}",
                srv_pull_url=f"tcp://{chief_addr}:{pull_port}",
            )
            return

        logging.debug(f"Chief setting up server with ports {pub_port}/{pull_port}.")
        # The server is sized for every slot in the trial except the chief's own.
        worker_count = self.env.experiment_config.slots_per_trial() - 1
        self.train_process_comm_chief = ipc.ZMQBroadcastServer(
            num_connections=worker_count,
            pub_port=pub_port,
            pull_port=pull_port,
        )
예제 #3
0
def main() -> None:
    """Worker process entry point: load the worker context and run its controller.

    Requires exactly one command-line argument, the path of the worker process
    env file; otherwise prints a message to stderr and exits with status 1.
    """
    if len(sys.argv) != 2:
        usage = "worker_process_env_path must be provided as a commandline argument"
        print(usage, file=sys.stderr)
        sys.exit(1)

    # Rebuild the worker process context from the file named on the command line.
    context = layers.WorkerProcessContext.from_file(pathlib.Path(sys.argv[1]))

    config_logging(context)

    # With debugging enabled, dump thread tracebacks every 30 seconds.
    debug_enabled = context.env.experiment_config.debug_enabled()
    if debug_enabled:
        faulthandler.dump_traceback_later(30, repeat=True)

    # The ZMQBroadcastServer lives in this container, so connect over localhost.
    pub_url = f"tcp://localhost:{context.broadcast_pub_port}"
    sub_url = f"tcp://localhost:{context.broadcast_pull_port}"

    with ipc.ZMQBroadcastClient(pub_url, sub_url) as client:
        # Wrap the communication layer in a workload.Stream.
        receiver = layers.SubprocessReceiver(client)

        controller = load.prepare_controller(
            context.env,
            iter(receiver),
            context.load_path,
            context.rendezvous_info,
            context.hvd_config,
        )
        controller.run()
예제 #4
0
 def main(self) -> None:
     """Run the client side of the broadcast server/client exchange test.

     Sends a ConnectedMessage first, then for every expected message checks
     what was received and replies with ``2 * msg``.
     """
     with ipc.ZMQBroadcastClient(self._pub_url, self._pull_url) as client:
         # The initial ConnectedMessage starts the server-client communication test.
         client.send(ipc.ConnectedMessage(process_id=0))
         for expected in self._exp_msgs:
             received = client.recv()
             assert received == expected
             client.send(2 * received)
예제 #5
0
 def main(self) -> None:
     """Run the client side of the broadcast server/client exchange test.

     Calls ``safe_start()`` on the client first, then for every expected
     message checks what was received and replies with ``2 * msg``.
     """
     with ipc.ZMQBroadcastClient(self._pub_url, self._pull_url) as client:
         # safe_start() begins the server-client communication test.
         client.safe_start()
         for expected in self._exp_msgs:
             received = client.recv()
             assert received == expected
             client.send(2 * received)
예제 #6
0
def main() -> None:
    """Worker process entry point: load the worker context and run its controller.

    Requires exactly one command-line argument, the path of the worker process
    env file; otherwise prints a message to stderr and exits with status 1.

    Any exception raised by ``controller.run()`` is reported through the
    broadcast client and then re-raised unchanged.
    """
    if len(sys.argv) != 2:
        print("worker_process_env_path must be provided as a commandline argument", file=sys.stderr)
        sys.exit(1)

    # Load the worker process env.
    worker_process_env_path = pathlib.Path(sys.argv[1])
    worker_process_env = layers.WorkerProcessContext.from_file(worker_process_env_path)

    config_logging(worker_process_env)

    env = worker_process_env.env

    # API code expects credential to be available as an environment variable
    os.environ["DET_TASK_TOKEN"] = env.det_task_token

    # TODO: refactor websocket, data_layer, and profiling to not use the cli_cert.
    scheme = "https" if env.use_tls else "http"
    master_url = f"{scheme}://{env.master_addr}:{env.master_port}"
    certs.cli_cert = certs.default_load(master_url=master_url)

    # With debugging enabled, dump thread tracebacks every 30 seconds.
    if env.experiment_config.debug_enabled():
        faulthandler.dump_traceback_later(30, repeat=True)

    # Establish the connection to the ZMQBroadcastServer in this container.
    pub_url = f"tcp://localhost:{worker_process_env.broadcast_pub_port}"
    sub_url = f"tcp://localhost:{worker_process_env.broadcast_pull_port}"
    with ipc.ZMQBroadcastClient(pub_url, sub_url) as broadcast_client:

        # Wrap the communication layer in a workload.Stream.
        subrec = layers.SubprocessReceiver(broadcast_client)
        workloads = iter(subrec)

        with det._catch_sys_exit():
            with det._catch_init_invalid_hp(workloads):
                controller = load.prepare_controller(
                    env,
                    workloads,
                    worker_process_env.load_path,
                    worker_process_env.rendezvous_info,
                    worker_process_env.hvd_config,
                )

            try:
                controller.run()

            except Exception:
                # Report the failure over the broadcast client, then re-raise
                # the active exception with a bare `raise`.
                broadcast_client.send_exception_message()
                raise
예제 #7
0
    def __init__(
        self,
        env: det.EnvContext,
        hvd_config: horovod.HorovodContext,
        rendezvous_info: det.RendezvousInfo,
    ) -> None:
        """Store the contexts and, when horovod is in use, set up ZMQ comms.

        Without horovod the instance is marked as chief and no ZMQ endpoint is
        created. With horovod, rank 0 is the chief and constructs a
        ZMQBroadcastServer; every other rank constructs a ZMQBroadcastClient
        aimed at the first rendezvous IP address.
        """
        self._env = env
        self._hvd_config = hvd_config
        self._rendezvous_info = rendezvous_info

        if not self._hvd_config.use:
            # No horovod: this process is the chief and needs no comms.
            self._is_chief = True
            return

        self._is_chief = horovod.hvd.rank() == 0

        # Initialize zmq comms. Both ports are the base constants shifted by
        # the trial's unique port offset.
        pub_port = (
            constants.INTER_TRAIN_PROCESS_COMM_PORT_1 + self._env.det_trial_unique_port_offset
        )
        pull_port = (
            constants.INTER_TRAIN_PROCESS_COMM_PORT_2 + self._env.det_trial_unique_port_offset
        )

        if not self._is_chief:
            chief_addr = self._rendezvous_info.get_ip_addresses()[0]
            logging.debug(
                f"Non-Chief {horovod.hvd.rank()} setting up comm to {chief_addr} w/ ports "
                f"{pub_port}/{pull_port}."
            )
            self._worker_zmq = ipc.ZMQBroadcastClient(
                srv_pub_url=f"tcp://{chief_addr}:{pub_port}",
                srv_pull_url=f"tcp://{chief_addr}:{pull_port}",
            )
            return

        logging.debug(f"Chief setting up server with ports {pub_port}/{pull_port}.")
        # The server is sized for every slot in the trial except the chief's own.
        worker_count = self._env.experiment_config.slots_per_trial() - 1
        self._chief_zmq = ipc.ZMQBroadcastServer(
            num_connections=worker_count,
            pub_port=pub_port,
            pull_port=pull_port,
        )
    def _init_ipc(self, force_tcp: bool) -> None:
        """Set up the global and per-machine ZMQ broadcast endpoints.

        Global: when ``self.size`` is at least 2, the chief starts a
        ZMQBroadcastServer on ``self._pub_port``/``self._pull_port`` and every
        other process starts a ZMQBroadcastClient connected to ``self._chief_ip``.

        Local: when ``self.local_size`` is also at least 2, the local chief
        starts a second server and shares its URLs through ``self.allgather``;
        every other process picks the entry whose cross rank matches its own and
        connects a client to it.

        Args:
            force_tcp: When true, the local server is never given ``ipc://``
                unix-socket URLs, even if the platform has ``socket.AF_UNIX``;
                local clients then connect over ``tcp://localhost`` using the
                ports reported by the server.

        Note:
            ``self.tempdir`` is only assigned once both early returns have been
            passed, i.e. when a local broadcast server is actually needed.
        """
        if self.size < 2:
            # No broadcasting necessary.
            return

        # Global broadcast server.
        if self._is_chief:
            logging.debug(f"Chief setting up server with ports {self._pub_port}/{self._pull_port}.")
            self._chief_zmq = ipc.ZMQBroadcastServer(
                num_connections=self.size - 1,
                pub_url=f"tcp://*:{self._pub_port}",
                pull_url=f"tcp://*:{self._pull_port}",
            )
            self._chief_zmq.safe_start(lambda: None)

        else:
            logging.debug(
                f"Non-Chief {self.rank} setting up comm to "
                f"{self._chief_ip} w/ ports "
                f"{self._pub_port}/{self._pull_port}."
            )
            self._worker_zmq = ipc.ZMQBroadcastClient(
                srv_pub_url=f"tcp://{self._chief_ip}:{self._pub_port}",
                srv_pull_url=f"tcp://{self._chief_ip}:{self._pull_port}",
            )
            self._worker_zmq.safe_start()

        if self.local_size < 2:
            # No local broadcasting necessary.
            return

        # Local broadcast server.
        self.tempdir = None
        if self._is_local_chief:
            pub_url = None
            pull_url = None
            if hasattr(socket, "AF_UNIX") and not force_tcp:
                # On systems with unix sockets, we get a slight performance bump by using them.
                self.tempdir = tempfile.mkdtemp(prefix="ipc")
                pub_url = f"ipc://{self.tempdir}/pub.sock"
                pull_url = f"ipc://{self.tempdir}/pull.sock"

            logging.debug(f"Local Chief setting up server with urls {pub_url}/{pull_url}.")
            self._local_chief_zmq = ipc.ZMQBroadcastServer(
                num_connections=self.local_size - 1,
                pub_url=pub_url,
                pull_url=pull_url,
            )

            # When no explicit URL was given, build a localhost URL from the
            # port the server reports.
            if pub_url is None:
                pub_url = f"tcp://localhost:{self._local_chief_zmq.get_pub_port()}"

            if pull_url is None:
                pull_url = f"tcp://localhost:{self._local_chief_zmq.get_pull_port()}"

            # Do a global allgather to initialize local clients on every node.
            local_chief = (self.cross_rank, pub_url, pull_url)
            _ = self.allgather(local_chief)
            self._local_chief_zmq.safe_start(lambda: None)

        else:
            # Start with the global allgather. Non-chiefs contribute None and
            # only read the local chiefs' entries.
            all_local_chiefs = self.allgather(None)
            my_local_chief = [
                x for x in all_local_chiefs if x is not None and x[0] == self.cross_rank
            ]
            assert len(my_local_chief) == 1, (
                f"did not find exactly 1 local_chief for machine {self.cross_rank} "
                f"in {all_local_chiefs}"
            )
            _, pub_url, pull_url = my_local_chief[0]

            assert isinstance(pub_url, str), f"invalid pub_url: {pub_url}"
            # Fixed: this message previously said "pub_url" while checking pull_url.
            assert isinstance(pull_url, str), f"invalid pull_url: {pull_url}"

            # NOTE(review): the log text says "server" but this branch builds a
            # client; left unchanged in case log consumers match on it.
            logging.debug(f"Local Worker setting up server with urls {pub_url}/{pull_url}.")
            self._local_worker_zmq = ipc.ZMQBroadcastClient(pub_url, pull_url)
            self._local_worker_zmq.safe_start()