def make_mock_cluster_info(container_addrs: List[str], container_rank: int, num_slots: int) -> det.ClusterInfo:
    """Build a ``det.ClusterInfo`` populated with fixed placeholder values for tests.

    Only ``container_addrs``, ``container_rank`` and ``num_slots`` vary between
    calls; every other field is a deterministic stand-in.
    """
    exp_config = utils.make_default_exp_config({}, 100, "loss", None)
    trial_info = det.TrialInfo(
        trial_id=1,
        experiment_id=1,
        trial_seed=0,
        hparams={},
        config=exp_config,
        steps_completed=0,
        trial_run_id=0,
        debug=False,
        unique_port_offset=0,
        inter_node_network_interface=None,
    )
    rendezvous_info = det.RendezvousInfo(
        container_addrs=container_addrs, container_rank=container_rank
    )
    return det.ClusterInfo(
        master_url="localhost",
        cluster_id="clusterId",
        agent_id="agentId",
        slot_ids=list(range(num_slots)),
        task_id="taskId",
        allocation_id="allocationId",
        session_token="sessionToken",
        task_type="TRIAL",
        rendezvous_info=rendezvous_info,
        trial_info=trial_info,
    )
def do_rendezvous_rm_provided(sess: Session, allocation_id: str, resources_id: str) -> "det.RendezvousInfo":
    """Fetch rendezvous info for these resources directly from the resource manager."""
    resp = bindings.get_AllocationRendezvousInfo(
        sess, allocationId=allocation_id, resourcesId=resources_id
    )
    info = resp.rendezvousInfo
    return det.RendezvousInfo(
        container_addrs=list(info.addresses),
        container_rank=info.rank,
    )
def _make_local_execution_env(
    managed_training: bool,
    test_mode: bool,
    config: Optional[Dict[str, Any]],
    hparams: Optional[Dict[str, Any]] = None,
    limit_gpus: Optional[int] = None,
) -> Tuple[det.EnvContext, det.RendezvousInfo, horovod.HorovodContext]:
    """Assemble the (env, rendezvous, horovod) triple for a local off-cluster run.

    Missing hyperparameters are filled in with randomly generated values from
    the experiment config's search space.
    """
    exp_config = det.ExperimentConfig(
        _make_local_execution_exp_config(
            config, managed_training=managed_training, test_mode=test_mode
        )
    )
    hparam_values = hparams or api.generate_random_hparam_values(
        exp_config.get("hyperparameters", {})
    )
    use_gpu, container_gpus, slot_ids = _get_gpus(limit_gpus)
    # Two consecutive local ports, encoded as the comma-separated string
    # EnvContext expects.
    ports_str = f"{constants.LOCAL_RENDEZVOUS_PORT},{constants.LOCAL_RENDEZVOUS_PORT+1}"
    env = det.EnvContext(
        master_addr="",
        master_port=0,
        use_tls=False,
        master_cert_file=None,
        master_cert_name=None,
        container_id="",
        experiment_config=exp_config,
        hparams=hparam_values,
        initial_workload=workload.train_workload(1, 1, 1, exp_config.scheduling_unit()),
        latest_checkpoint=None,
        use_gpu=use_gpu,
        container_gpus=container_gpus,
        slot_ids=slot_ids,
        debug=exp_config.debug_enabled(),
        workload_manager_type="",
        det_rendezvous_ports=ports_str,
        det_trial_unique_port_offset=0,
        det_trial_runner_network_interface=constants.AUTO_DETECT_TRIAL_RUNNER_NETWORK_INTERFACE,
        det_trial_id="",
        det_experiment_id="",
        det_cluster_id="",
        trial_seed=exp_config.experiment_seed(),
        managed_training=managed_training,
        test_mode=test_mode,
        on_cluster=False,
    )
    # Single-container rendezvous: bind the wildcard address on both ports.
    ports = env.rendezvous_ports()
    rendezvous_info = det.RendezvousInfo(
        addrs=[f"0.0.0.0:{ports[0]}"],
        addrs2=[f"0.0.0.0:{ports[1]}"],
        rank=0,
    )
    hvd_config = horovod.HorovodContext.from_configs(
        env.experiment_config, rendezvous_info, env.hparams
    )
    return env, rendezvous_info, hvd_config
def _make_local_test_experiment_env(
    checkpoint_dir: pathlib.Path,
    config: Optional[Dict[str, Any]],
    hparams: Optional[Dict[str, Any]] = None,
) -> Tuple[det.EnvContext, workload.Stream, det.RendezvousInfo, horovod.HorovodContext]:
    """Assemble env, workload stream, rendezvous and horovod config for test mode.

    Checkpoints produced by the generated workloads land under
    ``checkpoint_dir / "checkpoint"``.
    """
    exp_config = det.ExperimentConfig(_make_local_test_experiment_config(config))
    hparam_values = hparams or _generate_test_hparam_values(exp_config)
    use_gpu, container_gpus, slot_ids = _get_gpus()
    # Two consecutive local ports, encoded as the comma-separated string
    # EnvContext expects.
    ports_str = f"{constants.LOCAL_RENDEZVOUS_PORT},{constants.LOCAL_RENDEZVOUS_PORT+1}"
    env = det.EnvContext(
        master_addr="",
        master_port=1,
        container_id="test_mode",
        experiment_config=exp_config,
        hparams=hparam_values,
        initial_workload=workload.train_workload(1, 1, 1, exp_config.batches_per_step()),
        latest_checkpoint=None,
        use_gpu=use_gpu,
        container_gpus=container_gpus,
        slot_ids=slot_ids,
        debug=exp_config.debug_enabled(),
        workload_manager_type="",
        det_rendezvous_ports=ports_str,
        det_trial_runner_network_interface=constants.AUTO_DETECT_TRIAL_RUNNER_NETWORK_INTERFACE,
        det_trial_id="1",
        det_experiment_id="1",
        det_cluster_id="test_mode",
        trial_seed=exp_config.experiment_seed(),
    )
    workloads = _make_test_workloads(checkpoint_dir.joinpath("checkpoint"), exp_config)
    # Single-container rendezvous: bind the wildcard address on both ports.
    ports = env.rendezvous_ports()
    rendezvous_info = det.RendezvousInfo(
        addrs=[f"0.0.0.0:{ports[0]}"],
        addrs2=[f"0.0.0.0:{ports[1]}"],
        rank=0,
    )
    hvd_config = horovod.HorovodContext.from_configs(
        env.experiment_config, rendezvous_info, env.hparams
    )
    return env, workloads, rendezvous_info, hvd_config
def check_for_rendezvous_info(self, event: Any) -> Optional[det.RendezvousInfo]:
    """
    Wait for a message from the socket, and check if it is a det.RendezvousInfo.

    Raise an error if a Workload is seen, since those should only come after
    det.RendezvousInfo.
    """
    # Log-only messages and non-text events never carry rendezvous info.
    if self.message_is_log_only(event):
        return None
    if not isinstance(event, lomond.events.Text):
        logging.warning(f"unexpected websocket event: {event}")
        return None

    msg = simplejson.loads(event.text)
    if msg["type"] == "RUN_WORKLOAD":
        raise ValueError("Received workload before rendezvous info")
    if msg["type"] != "RENDEZVOUS_INFO":
        return None

    logging.info("Got rendezvous information: %s", msg)
    # If there's only one container, there's nothing to do for rendezvous.
    rank = msg["rank"]
    addrs = msg["addrs"]
    addrs2 = msg["addrs2"]
    # The rendezvous info contains the external addresses for all the
    # containers, but we need to set what to actually bind to inside this
    # container. We just bind to the wildcard interface, on a fixed port that
    # matches the one the agent is hardcoded to expose in all trial containers.
    # TODO(DET-916): Make number of ports configurable.
    ports = self.env.rendezvous_ports()
    addrs[rank] = f"0.0.0.0:{ports[0]}"
    addrs2[rank] = f"0.0.0.0:{ports[1]}"
    # TODO(ryan): remove rendezvous info as a environment variable.
    os.environ["DET_RENDEZVOUS_INFO"] = simplejson.dumps(
        {"addrs": addrs, "addrs2": addrs2, "rank": rank}
    )
    return det.RendezvousInfo(addrs, addrs2, rank)
def do_rendezvous_slurm(sess: Session, allocation_id: str, resources_id: str) -> "det.RendezvousInfo":
    """Rendezvous via an all-gather through the master, using Slurm-provided rank info.

    Each peer resolves its own IP, publishes it with its rank, and receives the
    full rank-ordered address list back.
    """
    rank_str = os.environ.get("SLURM_PROCID")
    assert rank_str, "Unable to complete rendezvous without SLURM_PROCID"
    rank = int(rank_str)

    num_peers_str = os.environ.get("SLURM_NPROCS")
    assert num_peers_str, "Unable to complete rendezvous without SLURM_NPROCS"
    num_peers = int(num_peers_str)

    rendezvous_ip = None
    resolution_error = None
    # Try each candidate interface; the for/else fallback fires only when no
    # interface resolved (i.e. the loop never hit `break`).
    for iface in rendezvous_ifaces():
        try:
            rendezvous_ip = get_ip_from_interface(iface)
            break
        except ValueError as e:
            resolution_error = e
    else:
        logging.warning(
            f"falling back to naive ip resolution after:\n\t{resolution_error}"
        )
        rendezvous_ip = socket.gethostbyname(socket.gethostname())

    resp = bindings.post_AllocationAllGather(
        sess,
        allocationId=allocation_id,
        body=bindings.v1AllocationAllGatherRequest(
            allocationId=allocation_id,
            requestUuid=str(uuid.uuid4()),
            numPeers=num_peers,
            data={"rank": rank, "rendezvous_ip": rendezvous_ip},
        ),
    )
    # Note, rendezvous must be sorted in rank order.
    ordered = sorted(resp.data, key=lambda d: int(d["rank"]))
    addrs = [peer["rendezvous_ip"] for peer in ordered]
    return det.RendezvousInfo(container_addrs=addrs, container_rank=rank)
def make_default_rendezvous_info() -> det.RendezvousInfo:
    """Return single-node rendezvous info pointing at localhost defaults."""
    primary_addrs = ["127.0.0.1:1750"]
    secondary_addrs = [f"127.0.0.1:{constants.LOCAL_RENDEZVOUS_PORT}"]
    return det.RendezvousInfo(addrs=primary_addrs, addrs2=secondary_addrs, rank=0)
def test_distributed_context(cross_size: int, local_size: int, force_tcp: bool) -> None:
    """Exercise every ZMQ collective of DistributedContext with one thread per rank.

    ``cross_size`` is the number of (simulated) nodes, ``local_size`` the number
    of workers per node; global size is their product.
    """
    size = cross_size * local_size
    # Generate one rendezvous_info per node.
    rendezvous_info = [
        det.RendezvousInfo(
            addrs=["localhost:12345"] * cross_size,
            rank=i,
        )
        for i in range(cross_size)
    ]

    def do_parallel(fn: Callable) -> List:
        """
        Run the same function on one-thread-per-rank, assert there were no
        exceptions, and return the results from each rank.
        """
        # Each thread writes only its own slot, so the shared lists need no lock.
        results = [None] * size  # type: List
        errors = [None] * size  # type: List
        threads = []
        for cross_rank, local_rank in itertools.product(
            range(cross_size), range(local_size)
        ):
            rank = cross_rank * local_size + local_rank

            # rank/cross_rank/local_rank are passed as args (not captured) so
            # each thread sees its own values rather than the loop's last ones.
            def _fn(rank: int, cross_rank: int, local_rank: int) -> None:
                try:
                    results[rank] = fn(rank, cross_rank, local_rank)
                except Exception:
                    errors[rank] = traceback.format_exc()
                    raise

            threads.append(
                threading.Thread(target=_fn, args=(rank, cross_rank, local_rank))
            )
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        assert errors == [None] * size, "not all threads exited without error"
        return results

    # Create all of the DistributedContexts.
    def make_distributed_context(rank: int, cross_rank: int, local_rank: int) -> Any:
        rank_info = det.RankInfo(
            rank=cross_rank * local_size + local_rank,
            size=cross_size * local_size,
            local_rank=local_rank,
            local_size=local_size,
            cross_rank=cross_rank,
            cross_size=cross_size,
        )
        return det.DistributedContext(
            rank_info=rank_info,
            rendezvous_info=rendezvous_info[cross_rank],
            unique_port_offset=0,
            force_tcp=force_tcp,
        )

    contexts = do_parallel(make_distributed_context)

    # Perform a broadcast: every rank should receive the chief's value (rank 0).
    results = do_parallel(
        lambda rank, _, __: contexts[rank]._zmq_broadcast(rank)
    )
    assert results == [0] * size, "not all threads ran broadcast correctly"

    # Perform a local broadcast: every rank should receive its node chief's rank.
    results = do_parallel(
        lambda rank, _, __: contexts[rank]._zmq_broadcast_local(rank)
    )
    expect = [rank - (rank % local_size) for rank in range(size)]  # type: Any
    assert results == expect, "not all threads ran broadcast_local correctly"

    # Perform a gather: only the global chief sees all ranks; others get nothing.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_gather(rank) or [])
    )
    chief = set(range(size))
    # NOTE(review): this `expect` is dead — the assert below compares against
    # `[chief] + [set()] * (size - 1)` instead. Candidate for cleanup.
    expect = [set(range(size)) if rank == 0 else set() for rank in range(size)]
    assert results == [
        chief
    ] + [set()] * (size - 1), "not all threads ran gather correctly"

    # Perform a local gather: only each node chief sees its node's ranks.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_gather_local(rank) or [])
    )
    expect = [
        set(range(rank, rank + local_size)) if rank % local_size == 0 else set()
        for rank in range(size)
    ]
    assert results == expect, "not all threads ran gather correctly"

    # Perform an allgather: every rank sees every rank.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_allgather(rank))
    )
    expect = set(range(size))
    assert results == [
        expect
    ] * size, "not all threads ran allgather correctly"

    # Perform a local allgather: every rank sees all ranks on its own node.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_allgather_local(rank))
    )
    expect = [
        set(range(cross_rank * local_size, (cross_rank + 1) * local_size))
        for cross_rank, _ in itertools.product(range(cross_size), range(local_size))
    ]
    assert results == expect, "not all threads ran allgather_local correctly"

    # Close all contexts.
    for context in contexts:
        context.close()