def build_and_run_training_pipeline(env: det.EnvContext) -> None:
    """Assemble the full training pipeline for one trial and run it to completion.

    Wires together, in order: the socket connection to the master, checkpoint
    storage, TensorBoard management, the workload manager, the Horovod config,
    and finally either a Horovod subprocess launcher or an in-process trial
    controller. All resources are scoped to context managers so they are torn
    down when training ends.
    """
    # Create the socket manager. The socket manager will connect to the master and read messages
    # until it receives the rendezvous_info.
    #
    # TODO(ryan): Pull profiler hooks out of SocketManager and into their own layer.
    with layers.SocketManager(env) as socket_mgr:
        # Create the storage manager. This is used to download the initial checkpoint here in
        # build_training_pipeline and also used by the workload manager to create and store
        # checkpoints during training.
        storage_mgr = storage.build(env.experiment_config["checkpoint_storage"])

        [tensorboard_mgr, tensorboard_writer] = load.prepare_tensorboard(env)

        # Create the workload manager. The workload manager will receive workloads from the
        # socket_mgr, and augment them with some additional arguments. Additionally, the
        # workload manager is responsible for some generic workload hooks for things like timing
        # workloads, preparing checkpoints, and uploading completed checkpoints. Finally, the
        # workload manager does some sanity checks on response messages that originate from the
        # trial.
        #
        # TODO(ryan): Refactor WorkloadManager into separate layers that do each separate task.
        workload_mgr = layers.build_workload_manager(
            env,
            iter(socket_mgr),
            socket_mgr.get_rendezvous_info(),
            storage_mgr,
            tensorboard_mgr,
            tensorboard_writer,
        )

        hvd_config = horovod.HorovodContext.from_configs(
            env.experiment_config, socket_mgr.get_rendezvous_info(), env.hparams
        )
        logging.info(f"Horovod config: {hvd_config.__dict__}.")

        # Load the checkpoint, if necessary. Any possible sinks to this pipeline will need access
        # to this checkpoint.
        with maybe_load_checkpoint(storage_mgr, env.latest_checkpoint) as load_path:

            # Horovod distributed training is done inside subprocesses.
            if hvd_config.use:
                subproc = layers.SubprocessLauncher(
                    env, iter(workload_mgr), load_path, socket_mgr.get_rendezvous_info(), hvd_config
                )
                subproc.run()
            else:
                if env.experiment_config.debug_enabled():
                    # Periodically dump all thread stacks to aid debugging of hangs.
                    faulthandler.dump_traceback_later(30, repeat=True)

                controller = load.prepare_controller(
                    env,
                    iter(workload_mgr),
                    load_path,
                    socket_mgr.get_rendezvous_info(),
                    hvd_config,
                )
                controller.run()
def test_subprocess_launcher_receiver() -> None:
    """Drive a SubprocessLauncher against the fake receiver fixture.

    Feeds generated fake workloads through a response interceptor and checks
    that the fixture echoes back a ``{"count": i}`` metric for each workload,
    in order.
    """
    # Minimal context objects required to construct the launcher.
    env_context = utils.make_default_env_context(hparams={"global_batch_size": 1})
    rendezvous_info = utils.make_default_rendezvous_info()
    hvd_config = utils.make_default_hvd_config()

    def workload_stream() -> workload.Stream:
        # Intercept each workload's response so we can assert on its metrics.
        interceptor = workload.WorkloadResponseInterceptor()
        expected = 0
        for wkld in fake_subprocess_receiver.fake_workload_gen():
            yield from interceptor.send(wkld, [])
            assert interceptor.metrics_result() == {"count": expected}
            expected += 1

    launcher = layers.SubprocessLauncher(
        env=env_context,
        workloads=workload_stream(),
        load_path=None,
        rendezvous_info=rendezvous_info,
        hvd_config=hvd_config,
        python_subprocess_entrypoint="tests.fixtures.fake_subprocess_receiver",
    )
    launcher.run()