import faulthandler
import logging
import os
import pathlib
import sys

# NOTE: imports reconstructed for readability; the determined-internal paths
# below assume a contemporary layout of the package and may differ slightly.
import determined as det
from determined import constants, horovod, ipc, layers, load, storage
from determined.common.api import certs


def main() -> None:
    if len(sys.argv) != 2:
        print(
            "worker_process_env_path must be provided as a commandline argument",
            file=sys.stderr,
        )
        sys.exit(1)

    # Load the worker process env.
    worker_process_env_path = pathlib.Path(sys.argv[1])
    worker_process_env = layers.WorkerProcessContext.from_file(worker_process_env_path)

    config_logging(worker_process_env)

    # API code expects the credential to be available as an environment variable.
    os.environ["DET_TASK_TOKEN"] = worker_process_env.env.det_task_token

    # TODO: refactor websocket, data_layer, and profiling to not use the cli_cert.
    master_url = (
        f"http{'s' if worker_process_env.env.use_tls else ''}://"
        f"{worker_process_env.env.master_addr}:{worker_process_env.env.master_port}"
    )
    certs.cli_cert = certs.default_load(master_url=master_url)

    if worker_process_env.env.experiment_config.debug_enabled():
        faulthandler.dump_traceback_later(30, repeat=True)

    # Establish the connection to the ZMQBroadcastServer in this container.
    pub_url = f"tcp://localhost:{worker_process_env.broadcast_pub_port}"
    sub_url = f"tcp://localhost:{worker_process_env.broadcast_pull_port}"
    with ipc.ZMQBroadcastClient(pub_url, sub_url) as broadcast_client:
        # Wrap the communication layer in a workload.Stream.
        subrec = layers.SubprocessReceiver(broadcast_client)
        workloads = iter(subrec)

        with det._catch_sys_exit():
            with det._catch_init_invalid_hp(workloads):
                controller = load.prepare_controller(
                    worker_process_env.env,
                    workloads,
                    worker_process_env.load_path,
                    worker_process_env.rendezvous_info,
                    worker_process_env.hvd_config,
                )

            try:
                controller.run()
            except Exception as e:
                broadcast_client.send_exception_message()
                raise e
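
# ZMQBroadcastClient is defined elsewhere and not shown in this excerpt. As a
# rough illustration of the IPC pattern it appears to wrap, here is a minimal
# SUB/PUSH pair in raw pyzmq: the server broadcasts workloads on a PUB socket
# and collects replies on a PULL socket, so each worker SUBscribes for
# workloads and PUSHes its response back. The function name and the pickle
# message framing (recv_pyobj/send_pyobj) are illustrative assumptions, not
# Determined's actual wire format.
import zmq


def _illustrative_broadcast_worker(pub_port: int, pull_port: int) -> None:
    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.connect(f"tcp://localhost:{pub_port}")
    sub.setsockopt(zmq.SUBSCRIBE, b"")  # subscribe to every broadcast message
    push = ctx.socket(zmq.PUSH)
    push.connect(f"tcp://localhost:{pull_port}")
    try:
        msg = sub.recv_pyobj()  # blocking read of one broadcast workload
        push.send_pyobj({"ack": msg})  # send a response back to the server
    finally:
        sub.close()
        push.close()
        ctx.term()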
def build_and_run_training_pipeline(env: det.EnvContext) -> None:
    # Create the socket manager. The socket manager will connect to the master and read messages
    # until it receives the rendezvous_info.
    #
    # TODO(ryan): Pull profiler hooks out of SocketManager and into their own layer.
    with layers.SocketManager(env) as socket_mgr:
        # Create the storage manager. This is used to download the initial checkpoint here in
        # build_training_pipeline and also used by the workload manager to create and store
        # checkpoints during training.
        storage_mgr = storage.build(
            env.experiment_config["checkpoint_storage"],
            container_path=constants.SHARED_FS_CONTAINER_PATH,
        )

        [tensorboard_mgr, tensorboard_writer] = load.prepare_tensorboard(
            env, constants.SHARED_FS_CONTAINER_PATH
        )

        # Create the workload manager. The workload manager will receive workloads from the
        # socket_mgr, and augment them with some additional arguments. Additionally, the
        # workload manager is responsible for some generic workload hooks for things like timing
        # workloads, preparing checkpoints, and uploading completed checkpoints. Finally, the
        # workload manager does some sanity checks on response messages that originate from the
        # trial.
        #
        # TODO(ryan): Refactor WorkloadManager into separate layers that do each separate task.
        workload_mgr = layers.build_workload_manager(
            env,
            iter(socket_mgr),
            socket_mgr.get_rendezvous_info(),
            storage_mgr,
            tensorboard_mgr,
            tensorboard_writer,
        )

        workloads = iter(workload_mgr)

        hvd_config = horovod.HorovodContext.from_configs(
            env.experiment_config, socket_mgr.get_rendezvous_info(), env.hparams
        )
        logging.info(f"Horovod config: {hvd_config.__dict__}.")

        # Load the checkpoint, if necessary. Any possible sinks to this pipeline will need access
        # to this checkpoint.
        with maybe_load_checkpoint(storage_mgr, env.latest_checkpoint) as load_path:
            # Horovod distributed training is done inside subprocesses.
            if hvd_config.use:
                subproc = layers.SubprocessLauncher(
                    env, workloads, load_path, socket_mgr.get_rendezvous_info(), hvd_config
                )
                subproc.run()
            else:
                if env.experiment_config.debug_enabled():
                    faulthandler.dump_traceback_later(30, repeat=True)

                with det._catch_sys_exit():
                    with det._catch_init_invalid_hp(workloads):
                        controller = load.prepare_controller(
                            env,
                            workloads,
                            load_path,
                            socket_mgr.get_rendezvous_info(),
                            hvd_config,
                        )
                        controller.run()
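
# maybe_load_checkpoint is referenced above but defined elsewhere in this
# module. A minimal sketch of one plausible shape, assuming the storage
# manager exposes a restore_path() context manager (an assumption; the real
# helper may convert the checkpoint dict to richer metadata first): yield
# None when there is no checkpoint to restore, otherwise download the
# checkpoint and yield the local path for the duration of the with-block.
import contextlib
from typing import Any, Dict, Iterator, Optional


@contextlib.contextmanager
def _sketch_maybe_load_checkpoint(
    storage_mgr: Any, checkpoint: Optional[Dict[str, Any]]
) -> Iterator[Optional[pathlib.Path]]:
    if checkpoint is None:
        # Fresh trial: nothing to download, the controller starts from scratch.
        yield None
    else:
        # Assumed API: restore_path() downloads the checkpoint and cleans up
        # the local copy when the with-block exits.
        with storage_mgr.restore_path(checkpoint) as path:
            yield pathlib.Path(path)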