Example #1
0
def make_mock_cluster_info(container_addrs: List[str], container_rank: int,
                           num_slots: int) -> det.ClusterInfo:
    """Build a det.ClusterInfo populated with fixed mock identifiers.

    The rendezvous info carries the caller-provided container addresses and
    rank, the trial info wraps a default experiment config, and the slot ids
    are simply 0..num_slots-1.
    """
    exp_config = utils.make_default_exp_config({}, 100, "loss", None)
    rendezvous = det.RendezvousInfo(
        container_addrs=container_addrs,
        container_rank=container_rank,
    )
    trial = det.TrialInfo(
        trial_id=1,
        experiment_id=1,
        trial_seed=0,
        hparams={},
        config=exp_config,
        steps_completed=0,
        trial_run_id=0,
        debug=False,
        unique_port_offset=0,
        inter_node_network_interface=None,
    )
    return det.ClusterInfo(
        master_url="localhost",
        cluster_id="clusterId",
        agent_id="agentId",
        slot_ids=list(range(num_slots)),
        task_id="taskId",
        allocation_id="allocationId",
        session_token="sessionToken",
        task_type="TRIAL",
        rendezvous_info=rendezvous,
        trial_info=trial,
    )
Example #2
0
def do_rendezvous_rm_provided(sess: Session, allocation_id: str,
                              resources_id: str) -> "det.RendezvousInfo":
    """Fetch master-provided rendezvous info for this allocation's resources."""
    info = bindings.get_AllocationRendezvousInfo(
        sess,
        allocationId=allocation_id,
        resourcesId=resources_id,
    ).rendezvousInfo
    return det.RendezvousInfo(
        container_addrs=list(info.addresses),
        container_rank=info.rank,
    )
Example #3
0
def _make_local_execution_env(
    managed_training: bool,
    test_mode: bool,
    config: Optional[Dict[str, Any]],
    hparams: Optional[Dict[str, Any]] = None,
    limit_gpus: Optional[int] = None,
) -> Tuple[det.EnvContext, det.RendezvousInfo, horovod.HorovodContext]:
    config = det.ExperimentConfig(
        _make_local_execution_exp_config(config,
                                         managed_training=managed_training,
                                         test_mode=test_mode))
    hparams = hparams or api.generate_random_hparam_values(
        config.get("hyperparameters", {}))
    use_gpu, container_gpus, slot_ids = _get_gpus(limit_gpus)
    local_rendezvous_ports = (
        f"{constants.LOCAL_RENDEZVOUS_PORT},{constants.LOCAL_RENDEZVOUS_PORT+1}"
    )

    env = det.EnvContext(
        master_addr="",
        master_port=0,
        use_tls=False,
        master_cert_file=None,
        master_cert_name=None,
        container_id="",
        experiment_config=config,
        hparams=hparams,
        initial_workload=workload.train_workload(1, 1, 1,
                                                 config.scheduling_unit()),
        latest_checkpoint=None,
        use_gpu=use_gpu,
        container_gpus=container_gpus,
        slot_ids=slot_ids,
        debug=config.debug_enabled(),
        workload_manager_type="",
        det_rendezvous_ports=local_rendezvous_ports,
        det_trial_unique_port_offset=0,
        det_trial_runner_network_interface=constants.
        AUTO_DETECT_TRIAL_RUNNER_NETWORK_INTERFACE,
        det_trial_id="",
        det_experiment_id="",
        det_cluster_id="",
        trial_seed=config.experiment_seed(),
        managed_training=managed_training,
        test_mode=test_mode,
        on_cluster=False,
    )
    rendezvous_ports = env.rendezvous_ports()
    rendezvous_info = det.RendezvousInfo(
        addrs=[f"0.0.0.0:{rendezvous_ports[0]}"],
        addrs2=[f"0.0.0.0:{rendezvous_ports[1]}"],
        rank=0)
    hvd_config = horovod.HorovodContext.from_configs(env.experiment_config,
                                                     rendezvous_info,
                                                     env.hparams)

    return env, rendezvous_info, hvd_config
Example #4
0
def _make_local_test_experiment_env(
    checkpoint_dir: pathlib.Path,
    config: Optional[Dict[str, Any]],
    hparams: Optional[Dict[str, Any]] = None,
) -> Tuple[det.EnvContext, workload.Stream, det.RendezvousInfo,
           horovod.HorovodContext]:
    """Assemble the environment, workload stream, and horovod config for a
    local test-mode experiment.

    Args:
        checkpoint_dir: directory under which a "checkpoint" subdirectory is
            used by the generated test workloads.
        config: raw experiment config dict, or None for defaults.
        hparams: hyperparameter values; when falsy, test values are produced
            by _generate_test_hparam_values.

    Returns:
        A (EnvContext, workload Stream, RendezvousInfo, HorovodContext)
        tuple describing a single-container test-mode run (rank 0) bound to
        the wildcard interface.
    """
    config = det.ExperimentConfig(_make_local_test_experiment_config(config))
    hparams = hparams or _generate_test_hparam_values(config)
    use_gpu, container_gpus, slot_ids = _get_gpus()
    # Two consecutive ports are used for rendezvous, encoded as "port1,port2".
    local_rendezvous_ports = (
        f"{constants.LOCAL_RENDEZVOUS_PORT},{constants.LOCAL_RENDEZVOUS_PORT+1}"
    )

    env = det.EnvContext(
        master_addr="",
        master_port=1,
        container_id="test_mode",
        experiment_config=config,
        hparams=hparams,
        initial_workload=workload.train_workload(1, 1, 1,
                                                 config.batches_per_step()),
        latest_checkpoint=None,
        use_gpu=use_gpu,
        container_gpus=container_gpus,
        slot_ids=slot_ids,
        debug=config.debug_enabled(),
        workload_manager_type="",
        det_rendezvous_ports=local_rendezvous_ports,
        det_trial_runner_network_interface=constants.
        AUTO_DETECT_TRIAL_RUNNER_NETWORK_INTERFACE,
        det_trial_id="1",
        det_experiment_id="1",
        det_cluster_id="test_mode",
        trial_seed=config.experiment_seed(),
    )
    workloads = _make_test_workloads(checkpoint_dir.joinpath("checkpoint"),
                                     config)
    # Single-container rendezvous: this process is rank 0 and binds both
    # rendezvous ports on the wildcard address.
    rendezvous_ports = env.rendezvous_ports()
    rendezvous_info = det.RendezvousInfo(
        addrs=[f"0.0.0.0:{rendezvous_ports[0]}"],
        addrs2=[f"0.0.0.0:{rendezvous_ports[1]}"],
        rank=0)
    hvd_config = horovod.HorovodContext.from_configs(env.experiment_config,
                                                     rendezvous_info,
                                                     env.hparams)

    return env, workloads, rendezvous_info, hvd_config
Example #5
0
    def check_for_rendezvous_info(self, event: Any) -> Optional[det.RendezvousInfo]:
        """Inspect a websocket event and extract det.RendezvousInfo, if any.

        Returns None for log-only messages, non-text events, and unrelated
        message types. Raises ValueError if a workload arrives first, since
        workloads must only come after rendezvous info.
        """
        if self.message_is_log_only(event):
            return None
        if not isinstance(event, lomond.events.Text):
            logging.warning(f"unexpected websocket event: {event}")
            return None

        msg = simplejson.loads(event.text)
        if msg["type"] == "RUN_WORKLOAD":
            raise ValueError("Received workload before rendezvous info")
        if msg["type"] != "RENDEZVOUS_INFO":
            return None

        logging.info("Got rendezvous information: %s", msg)

        # If there's only one container, there's nothing to do for
        # rendezvous.
        addrs, rank = msg["addrs"], msg["rank"]
        addrs2 = msg["addrs2"]

        # The rendezvous info contains the external addresses for all the
        # containers, but this container must bind locally. We bind to the
        # wildcard interface, on the fixed ports that match what the agent
        # hardcodes when exposing ports in trial containers.
        # TODO(DET-916): Make number of ports configurable.
        rendezvous_ports = self.env.rendezvous_ports()
        addrs[rank] = f"0.0.0.0:{rendezvous_ports[0]}"
        addrs2[rank] = f"0.0.0.0:{rendezvous_ports[1]}"

        # TODO(ryan): remove rendezvous info as a environment variable.
        os.environ["DET_RENDEZVOUS_INFO"] = simplejson.dumps(
            {"addrs": addrs, "addrs2": addrs2, "rank": rank}
        )

        return det.RendezvousInfo(addrs, addrs2, rank)
Example #6
0
def do_rendezvous_slurm(sess: Session, allocation_id: str,
                        resources_id: str) -> "det.RendezvousInfo":
    """Rendezvous via an all-gather across SLURM tasks.

    Each task reads its rank and world size from SLURM_PROCID/SLURM_NPROCS,
    resolves its own rendezvous IP (preferring the configured interfaces),
    publishes it through the master's all-gather endpoint, and returns the
    rank-ordered list of peer addresses.
    """
    rank_str = os.environ.get("SLURM_PROCID")
    assert rank_str, "Unable to complete rendezvous without SLURM_PROCID"
    rank = int(rank_str)

    num_peers_str = os.environ.get("SLURM_NPROCS")
    assert num_peers_str, "Unable to complete rendezvous without SLURM_NPROCS"
    num_peers = int(num_peers_str)

    # Try each candidate interface in order; remember the last failure so it
    # can be reported if none of them resolves.
    rendezvous_ip = None
    resolution_error = None
    resolved = False
    for iface in rendezvous_ifaces():
        try:
            rendezvous_ip = get_ip_from_interface(iface)
            resolved = True
            break
        except ValueError as err:
            resolution_error = err
    if not resolved:
        logging.warning(
            f"falling back to naive ip resolution after:\n\t{resolution_error}"
        )
        rendezvous_ip = socket.gethostbyname(socket.gethostname())

    # Note, rendezvous must be sorted in rank order.
    request_body = bindings.v1AllocationAllGatherRequest(
        allocationId=allocation_id,
        requestUuid=str(uuid.uuid4()),
        numPeers=num_peers,
        data={
            "rank": rank,
            "rendezvous_ip": rendezvous_ip,
        },
    )
    resp = bindings.post_AllocationAllGather(
        sess,
        allocationId=allocation_id,
        body=request_body,
    )
    peers = sorted(resp.data, key=lambda d: int(d["rank"]))
    addrs = [peer["rendezvous_ip"] for peer in peers]
    return det.RendezvousInfo(container_addrs=addrs, container_rank=rank)
Example #7
0
def make_default_rendezvous_info() -> det.RendezvousInfo:
    """Return a rank-0, single-container RendezvousInfo pointing at localhost."""
    secondary_addr = f"127.0.0.1:{constants.LOCAL_RENDEZVOUS_PORT}"
    return det.RendezvousInfo(
        addrs=["127.0.0.1:1750"],
        addrs2=[secondary_addr],
        rank=0,
    )
Example #8
0
def test_distributed_context(cross_size: int, local_size: int,
                             force_tcp: bool) -> None:
    """Exercise det.DistributedContext collectives across threads.

    Spawns one thread per (cross_rank, local_rank) pair; each thread builds
    its own DistributedContext and participates in broadcast, gather, and
    allgather operations, in both global and node-local variants.

    Fix: the gather assertion previously computed ``expect`` and then
    ignored it, comparing against an equivalent literal built inline with a
    redundant ``chief`` local; it now asserts against ``expect`` directly.
    """
    size = cross_size * local_size
    # Generate one rendezvous_info per node.
    rendezvous_info = [
        det.RendezvousInfo(
            addrs=["localhost:12345"] * cross_size,
            rank=i,
        ) for i in range(cross_size)
    ]

    def do_parallel(fn: Callable) -> List:
        """
        Run the same function on one-thread-per-rank, assert there were no exceptions, and return
        the results from each rank.
        """
        results = [None] * size  # type: List
        errors = [None] * size  # type: List
        threads = []

        for cross_rank, local_rank in itertools.product(
                range(cross_size), range(local_size)):
            rank = cross_rank * local_size + local_rank

            # Bind the ranks as thread args (not via closure) so each thread
            # sees its own values rather than the loop's final iteration.
            def _fn(rank: int, cross_rank: int, local_rank: int) -> None:
                try:
                    results[rank] = fn(rank, cross_rank, local_rank)
                except Exception:
                    errors[rank] = traceback.format_exc()
                    raise

            threads.append(
                threading.Thread(target=_fn,
                                 args=(rank, cross_rank, local_rank)))

        for thread in threads:
            thread.start()

        for thread in threads:
            thread.join()

        assert errors == [None] * size, "not all threads exited without error"

        return results

    # Create all of the DistributedContexts.
    def make_distributed_context(rank: int, cross_rank: int,
                                 local_rank: int) -> Any:
        rank_info = det.RankInfo(
            rank=cross_rank * local_size + local_rank,
            size=cross_size * local_size,
            local_rank=local_rank,
            local_size=local_size,
            cross_rank=cross_rank,
            cross_size=cross_size,
        )
        return det.DistributedContext(
            rank_info=rank_info,
            rendezvous_info=rendezvous_info[cross_rank],
            unique_port_offset=0,
            force_tcp=force_tcp,
        )

    contexts = do_parallel(make_distributed_context)

    # Perform a broadcast.
    results = do_parallel(
        lambda rank, _, __: contexts[rank]._zmq_broadcast(rank))
    assert results == [0] * size, "not all threads ran broadcast correctly"

    # Perform a local broadcast.
    results = do_parallel(
        lambda rank, _, __: contexts[rank]._zmq_broadcast_local(rank))
    expect = [rank - (rank % local_size) for rank in range(size)]  # type: Any

    assert results == expect, "not all threads ran broadcast_local correctly"

    # Perform a gather: only the chief (rank 0) receives the full set.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_gather(rank) or []))
    expect = [set(range(size)) if rank == 0 else set() for rank in range(size)]
    assert results == expect, "not all threads ran gather correctly"

    # Perform a local gather: only each node's local chief receives values.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_gather_local(rank) or []))
    expect = [
        set(range(rank, rank + local_size)) if rank %
        local_size == 0 else set() for rank in range(size)
    ]
    assert results == expect, "not all threads ran gather correctly"

    # Perform an allgather: every rank receives every value.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_allgather(rank)))
    expect = set(range(size))
    assert results == [expect
                       ] * size, "not all threads ran allgather correctly"

    # Perform a local allgather: every rank receives its node's values.
    results = do_parallel(
        lambda rank, _, __: set(contexts[rank]._zmq_allgather_local(rank)))
    expect = [
        set(range(cross_rank * local_size, (cross_rank + 1) * local_size))
        for cross_rank, _ in itertools.product(range(cross_size),
                                               range(local_size))
    ]
    assert results == expect, "not all threads ran allgather_local correctly"

    # Close all contexts.
    for context in contexts:
        context.close()