Example 1
# `ray_start_2_cpus` is a pytest fixture that starts a local Ray cluster with
# two CPUs; `WorkerGroup` is the worker group abstraction under test.
def test_execute_single(ray_start_2_cpus):
    wg = WorkerGroup(num_workers=2)

    def f():
        import os
        os.environ["TEST"] = "1"

    # Run f on worker 1 only; worker 0's environment is left untouched.
    wg.execute_single(1, f)

    def check():
        import os
        return os.environ.get("TEST", "0")

    # execute() runs check on every worker and returns the results in worker
    # index order, so only worker 1 reports the variable as set.
    assert wg.execute(check) == ["0", "1"]
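
The test above relies on two WorkerGroup semantics: execute_single(i, f) runs f only on the worker at index i, while execute(f) runs f on every worker and returns the results in worker-index order. Below is a minimal, hypothetical sketch of those semantics built directly on Ray actors; it is not Ray's actual WorkerGroup implementation, and the class and method names are illustrative only.

import ray


@ray.remote
class _SketchWorker:
    """Hypothetical worker actor used only for illustration."""

    def run(self, fn, *args, **kwargs):
        # Execute an arbitrary function inside this worker's process.
        return fn(*args, **kwargs)


class SketchWorkerGroup:
    """Illustrative stand-in for WorkerGroup, not Ray's implementation."""

    def __init__(self, num_workers):
        self.workers = [_SketchWorker.remote() for _ in range(num_workers)]

    def __len__(self):
        return len(self.workers)

    def execute(self, fn, *args, **kwargs):
        # Run fn on all workers; results come back ordered by worker index.
        return ray.get(
            [w.run.remote(fn, *args, **kwargs) for w in self.workers])

    def execute_single(self, worker_index, fn, *args, **kwargs):
        # Run fn on a single worker, identified by its index.
        return ray.get(
            self.workers[worker_index].run.remote(fn, *args, **kwargs))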
Example 2
    def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig):
        if len(worker_group) > 1 and dist.is_available():
            # Choose the communication backend: NCCL when the workers have
            # GPUs, Gloo otherwise, unless one was set explicitly in the config.
            if backend_config.backend is None:
                if worker_group.num_gpus_per_worker > 0:
                    backend = "nccl"
                else:
                    backend = "gloo"
            else:
                backend = backend_config.backend

            # Worker 0 acts as the rendezvous point: ask it for its node
            # address and a free port.
            master_addr, master_port = worker_group.execute_single(
                0, get_address_and_port)
            if backend_config.init_method == "env":

                def set_env_vars(addr, port):
                    os.environ["MASTER_ADDR"] = addr
                    os.environ["MASTER_PORT"] = str(port)

                worker_group.execute(set_env_vars,
                                     addr=master_addr,
                                     port=master_port)
                url = "env://"
            elif backend_config.init_method == "tcp":
                url = f"tcp://{master_addr}:{master_port}"
            else:
                raise ValueError(
                    "The provided init_method "
                    f"({backend_config.init_method}) is not supported. Must "
                    "be either 'env' or 'tcp'.")

            # Kick off process-group setup on every worker in parallel; each
            # async call returns an ObjectRef immediately.
            setup_futures = []
            for i in range(len(worker_group)):
                setup_futures.append(
                    worker_group.execute_single_async(
                        i,
                        setup_torch_process_group,
                        backend=backend,
                        world_rank=i,
                        world_size=len(worker_group),
                        init_method=url,
                        timeout_s=backend_config.timeout_s))
            # Block until every worker has finished joining the process group.
            ray.get(setup_futures)
        else:
            logger.info("Distributed torch is not being used.")