def setUp(self) -> None:
        # For testing, the default parameters are for TCP. If a test uses
        # file store parameters, we set self._params to
        # self._params_filestore.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="localhost:29300",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            is_host="true",
            store_type="tCp",
            read_timeout="10",
        )

        _, tmp_path = tempfile.mkstemp()

        # Parameters for filestore testing.
        self._params_filestore = RendezvousParameters(
            backend="dummy_backend",
            endpoint=tmp_path,
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            store_type="fIlE",
        )
        self._expected_endpoint_file = tmp_path
        self._expected_temp_dir = tempfile.gettempdir()

        self._expected_endpoint_host = "localhost"
        self._expected_endpoint_port = 29300
        self._expected_store_type = TCPStore
        self._expected_read_timeout = timedelta(seconds=10)
Example #2
    def test_get_or_default(self):

        params = RendezvousParameters(
            backend="foobar",
            endpoint="localhost",
            run_id="1234",
            min_nodes=1,
            max_nodes=1,
            timeout1=10,
        )

        self.assertEqual(10, params.get("timeout1", 20))
        self.assertEqual(60, params.get("timeout2", 60))
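A brief, hedged aside (not part of the example above): besides get(), RendezvousParameters also exposes get_as_int(), used in the etcd example further below, which coerces free-form keyword arguments to integers. A minimal sketch; the parameter names and values here are illustrative:

from torch.distributed.elastic.rendezvous import RendezvousParameters

params = RendezvousParameters(
    backend="foobar",
    endpoint="localhost",
    run_id="1234",
    min_nodes=1,
    max_nodes=1,
    read_timeout="10",  # extra kwargs are stored as free-form config values
)

# get() returns the stored value (or the default when the key is missing);
# get_as_int() additionally converts the stored value to an int.
assert params.get("read_timeout", "60") == "10"
assert params.get_as_int("read_timeout", 60) == 10
assert params.get_as_int("last_call_timeout", 30) == 30  # key missing, default used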
Example #3
    def _get_worker_spec(
        self,
        max_restarts=1,
        monitor_interval=1.0,
        role="test_trainer",
        local_world_size=8,
    ):
        run_id = str(uuid.uuid4().int)
        endpoint = self._etcd_server.get_endpoint()

        rdzv_params = RendezvousParameters(backend="etcd",
                                           endpoint=endpoint,
                                           run_id=run_id,
                                           min_nodes=1,
                                           max_nodes=1)
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
        spec = WorkerSpec(
            role=role,
            local_world_size=local_world_size,
            fn=do_nothing,
            args=(),
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec
Example #4
 def get_worker_spec(
     self,
     node_config: Conf,
     min_nodes=1,
     max_nodes=1,
     max_restarts=0,
     monitor_interval=0.01,
 ):
     rdzv_params = RendezvousParameters(
         backend="etcd",
         endpoint=self._etcd_server.get_endpoint(),
         run_id=self._run_id,
         min_nodes=min_nodes,
         max_nodes=max_nodes,
     )
     rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
     return WorkerSpec(
         role=node_config.role,
         local_world_size=node_config.local_world_size,
         entrypoint=node_config.entrypoint,
         args=node_config.args,
         rdzv_handler=rdzv_handler,
         max_restarts=max_restarts,
         monitor_interval=monitor_interval,
         redirects=node_config.redirects,
         tee=node_config.tee,
     )
Example #5
 def _get_worker_spec(
     self,
     max_restarts=1,
     monitor_interval=1.0,
     role="test_trainer",
     local_world_size=8,
 ):
     run_id = str(uuid.uuid4().int)
     port = get_free_port()
     endpoint = f"127.0.0.1:{port}"
     rdzv_params = RendezvousParameters(
         backend="static",
         endpoint=endpoint,
         run_id=run_id,
         min_nodes=1,
         max_nodes=1,
         rank=0,
     )
     rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
     spec = WorkerSpec(
         role=role,
         local_world_size=local_world_size,
         fn=do_nothing,
         args=(),
         rdzv_handler=rdzv_handler,
         max_restarts=max_restarts,
         monitor_interval=monitor_interval,
     )
     return spec
Example #6
    def test_static_rdzv_multiple_calls(self):
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
        master_addr = "localhost"

        rdzv_params = RendezvousParameters(
            backend="static",
            endpoint=f"{master_addr}:{master_port}",
            run_id="test_id",
            min_nodes=1,
            max_nodes=1,
            rank=0,
        )
        rdzv_handler = create_rdzv_handler(rdzv_params)

        # Call rendezvous two times
        store, rank, world_size = rdzv_handler.next_rendezvous()
        self.assertIsNotNone(store)
        self.assertEqual(0, rank)
        self.assertEqual(1, world_size)

        store, rank, world_size = rdzv_handler.next_rendezvous()
        self.assertIsNotNone(store)
        self.assertEqual(0, rank)
        self.assertEqual(1, world_size)
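As a hedged follow-up sketch (not part of the test above): the (store, rank, world_size) triple returned by next_rendezvous() is typically used to seed process-group initialization. The port, key prefix, and gloo backend below are illustrative choices, and create_rdzv_handler is the same static-rendezvous factory used in the test above:

import torch.distributed as dist
from torch.distributed.elastic.rendezvous import RendezvousParameters

# Rebuild a single-node static rendezvous as in the test above.
rdzv_params = RendezvousParameters(
    backend="static",
    endpoint="localhost:29500",
    run_id="test_id",
    min_nodes=1,
    max_nodes=1,
    rank=0,
)
rdzv_handler = create_rdzv_handler(rdzv_params)
store, rank, world_size = rdzv_handler.next_rendezvous()

# Namespacing the store keeps process-group keys separate from rendezvous keys.
pg_store = dist.PrefixStore("trainer/", store)
dist.init_process_group(backend="gloo", store=pg_store, rank=rank, world_size=world_size)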
Example #7
 def get_worker_spec(
     self,
     node_config: Conf,
     min_nodes=1,
     max_nodes=1,
     max_restarts=0,
     monitor_interval=0.01,
     master_addr_override: Optional[str] = None,
     master_port_override: Optional[int] = None,
     is_host=True,
 ):
     rdzv_params = RendezvousParameters(
         backend=self._backend,
         endpoint=self._endpoint,
         run_id=self._run_id,
         min_nodes=min_nodes,
         max_nodes=max_nodes,
         is_host=is_host,
     )
     rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
     return WorkerSpec(
         role=node_config.role,
         local_world_size=node_config.local_world_size,
         entrypoint=node_config.entrypoint,
         args=node_config.args,
         rdzv_handler=rdzv_handler,
         max_restarts=max_restarts,
         monitor_interval=monitor_interval,
         redirects=node_config.redirects,
         tee=node_config.tee,
         master_addr=master_addr_override,
         master_port=master_port_override,
     )
Example #8
 def _create_params(self) -> RendezvousParameters:
     return RendezvousParameters(
         backend=self._backend,
         endpoint=self._endpoint,
         run_id=self._run_id,
         min_nodes=self._min_nodes,
         max_nodes=self._max_nodes,
         **self._kwargs,
     )
Example #9
    def test_no_factory_method_found(self):
        factory = RendezvousHandlerFactory()
        rdzv_params = RendezvousParameters(backend="mock",
                                           endpoint="",
                                           run_id="foobar",
                                           min_nodes=1,
                                           max_nodes=2)

        with self.assertRaises(ValueError):
            factory.create_handler(rdzv_params)
Example #10
    def setUp(self) -> None:
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
        )

        self._registry = RendezvousHandlerRegistry()
Example #11
 def test_ipv6_addr_localhost(self):
     rdzv_params = RendezvousParameters(
         backend="static",
         endpoint="[::1]:90",
         run_id="test_id",
         min_nodes=1,
         max_nodes=1,
     )
     with self.assertRaises(ValueError):
         create_rdzv_handler(rdzv_params)
Example #12
 def test_empty_endpoint(self):
     rdzv_params = RendezvousParameters(
         backend="static",
         endpoint="",
         run_id="test_id",
         min_nodes=1,
         max_nodes=1,
     )
     with self.assertRaises(ValueError):
         create_rdzv_handler(rdzv_params)
Example #13
 def test_ipv6_addr(self):
     rdzv_params = RendezvousParameters(
         backend="static",
         endpoint="[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:90",
         run_id="test_id",
         min_nodes=1,
         max_nodes=1,
     )
     with self.assertRaises(ValueError):
         create_rdzv_handler(rdzv_params)
Example #14
    def test_create_handler(self):
        rdzv_params = RendezvousParameters(backend="mock",
                                           endpoint="",
                                           run_id="foobar",
                                           min_nodes=1,
                                           max_nodes=2)

        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        mock_rdzv_handler = factory.create_handler(rdzv_params)
        self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))
Example #15
    def setUp(self) -> None:
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint=self._server.get_endpoint(),
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            protocol="hTTp",
            read_timeout="10",
        )

        self._expected_read_timeout = 10
Example #16
    def test_get_backend(self):
        rdzv_params = RendezvousParameters(
            backend="static",
            endpoint="localhost:123",
            run_id="test",
            min_nodes=1,
            max_nodes=1,
            timeout=60,
            rank=0,
        )

        static_rdzv = create_rdzv_handler(rdzv_params)
        self.assertEqual("static", static_rdzv.get_backend())
Example #17
 def test_etcd_rdzv_basic_params(self):
     """
     Check that we can create the handler with a minimum set of
     params
     """
     rdzv_params = RendezvousParameters(
         backend="etcd",
         endpoint=f"{self._etcd_server.get_endpoint()}",
         run_id=f"{uuid.uuid4()}",
         min_nodes=1,
         max_nodes=1,
     )
     etcd_rdzv = create_rdzv_handler(rdzv_params)
     self.assertIsNotNone(etcd_rdzv)
Example #18
    def test_get_backend(self):
        run_id = str(uuid.uuid4())
        rdzv_params = RendezvousParameters(
            backend="etcd",
            endpoint=f"{self._etcd_server.get_endpoint()}",
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            timeout=60,
            last_call_timeout=30,
            protocol="http",
        )

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertEqual("etcd", etcd_rdzv.get_backend())
Example #19
    def setUp(self) -> None:
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="localhost:29400",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            is_host="true",
            store_type="tCp",
            read_timeout="10",
        )

        self._expected_endpoint_host = "localhost"
        self._expected_endpoint_port = 29400
        self._expected_store_type = TCPStore
        self._expected_read_timeout = timedelta(seconds=10)
Example #20
    def setUp(self) -> None:
        self._store = DummyStore()

        self._backend = DummyRendezvousBackend()

        self._params = RendezvousParameters(
            backend=self._backend.name,
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=3,
            max_nodes=6,
            join_timeout="50",
            last_call_timeout="60",
            close_timeout="70",
        )

        self._expected_timeout = RendezvousTimeout(timedelta(seconds=50),
                                                   timedelta(seconds=60),
                                                   timedelta(seconds=70))
Example #21
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(
            f"config has no run_id, generated a random run_id: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f"  entrypoint       : {entrypoint_name}\n"
                f"  min_nodes        : {config.min_nodes}\n"
                f"  max_nodes        : {config.max_nodes}\n"
                f"  nproc_per_node   : {config.nproc_per_node}\n"
                f"  run_id           : {config.run_id}\n"
                f"  rdzv_backend     : {config.rdzv_backend}\n"
                f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
                f"  rdzv_configs     : {config.rdzv_configs}\n"
                f"  max_restarts     : {config.max_restarts}\n"
                f"  monitor_interval : {config.monitor_interval}\n"
                f"  log_dir          : {config.log_dir}\n"
                f"  metrics_cfg      : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
    )

    agent = LocalElasticAgent(spec=spec,
                              start_method=config.start_method,
                              log_dir=config.log_dir)

    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
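For orientation, a hedged sketch of how the launch_agent defined above might be invoked. The LaunchConfig import path and the concrete field values are assumptions; only attributes that launch_agent actually reads are set:

from torch.distributed.launcher.api import LaunchConfig  # assumed location of LaunchConfig

def trainer(x: int) -> int:
    # placeholder entrypoint; a real job would run its training loop here
    return x * 2

config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    run_id="",                       # empty: launch_agent generates a random run_id
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:29400",
    max_restarts=3,
    monitor_interval=5,
)

# Returns a dict mapping worker rank to the entrypoint's return value.
results = launch_agent(config, trainer, [21])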
Example #22
def main():
    args = parse_args()
    current_env = os.environ.copy()

    for k in current_env.keys():
        if "NCCL" in k:
            logger.info(f"{args.node_rank} {k}={current_env[k]}")

    if args.world_info == "None":
        raise ValueError("world_info can not be None")
    world_info = base64.urlsafe_b64decode(args.world_info)
    world_info = json.loads(world_info)

    logger.info(f"WORLD INFO DICT: {world_info}")
    node_list = list(world_info.keys())
    args.nnodes = len(node_list)
    local_node = node_list[args.node_rank]
    local_gpu_ids = world_info[local_node]
    num_local_procs = len(local_gpu_ids)
    logger.info(
        f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
    )

    global_rank_mapping = defaultdict(list)
    curr_global_rank = 0
    dist_world_size = 0
    for node_id in node_list:
        gids = world_info[node_id]
        dist_world_size += len(gids)
        for gid in gids:
            global_rank_mapping[node_id].append(curr_global_rank)
            curr_global_rank += 1
    logger.info(f"global_rank_mapping={global_rank_mapping}")
    logger.info(f"dist_world_size={dist_world_size}")
    current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids))
    logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}")

    # set PyTorch distributed related environmental variables
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    current_env["CROSS_RANK"] = str(args.node_rank)
    current_env["CROSS_SIZE"] = str(args.nnodes)
    current_env["LOCAL_SIZE"] = str(num_local_procs)

    if args.save_pid:
        print(f"launcher pid: {os.getpid()}")

    pid_file = None
    if args.save_pid:
        launcher_pid = os.getpid()
        pid_file = os.path.join(PID_FILE_BASEPATH, f"{args.save_pid}.deepspeed")
        assert not os.path.isfile(pid_file), "pid file exists but shouldn't"
        with open(pid_file, 'w') as fd:
            fd.write(f"{launcher_pid}")

    if not is_torch_elastic_compatible():
        if args.enable_elastic_training:
            logger.info(f"Disabling elastic training support as \
                    PyTorch version should be greater than 1.11.x")
            args.enable_elastic_training = False

    if os.path.exists(DLTS_POD_ENV_PATH):
        with open(DLTS_POD_ENV_PATH) as file:
            lines = file.readlines()
            lines = [line.rstrip() for line in lines]
            for line in lines:
                if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
                        'export FC_TASK_INDEX'):
                    key_val = line.split()[1]
                    key, val = key_val.split('=')
                    current_env[key] = val

    processes = []
    cmd = []

    if not args.enable_elastic_training:
        for local_rank in range(0, num_local_procs):
            # each process's rank
            dist_rank = global_rank_mapping[local_node][local_rank]
            current_env["RANK"] = str(dist_rank)
            current_env["LOCAL_RANK"] = str(local_rank)

            # spawn the processes
            cmd = []
            if not args.no_python:
                cmd = [sys.executable, "-u"]
                if args.module:
                    cmd.append("-m")
            else:
                if args.module:
                    raise ValueError("Don't use both the '--no_python' flag"
                                     " and the '--module' flag at the same time.")
            cmd.append(args.training_script)
            # A user may not want to pass local_rank as a keyword arg so we make this optional.
            if not args.no_local_rank:
                cmd.append(f"--local_rank={local_rank}")
            cmd += args.training_script_args

            process = subprocess.Popen(cmd, env=current_env)
            processes.append(process)
    else:
        from ..elasticity import DSElasticAgent
        from torch.distributed.elastic.rendezvous import RendezvousParameters
        from torch.distributed.elastic.agent.server.api import WorkerSpec
        import torch.distributed.elastic.rendezvous.registry as rdzv_registry
        from torch.distributed.elastic.multiprocessing import Std

        if args.min_elastic_nodes == -1:
            args.min_elastic_nodes = 1
        if args.max_elastic_nodes == -1:
            args.max_elastic_nodes = args.nnodes
        assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, "Max and Min nodes should be positive"

        current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)

        # Get config and arguments
        cmd = []
        if not args.no_python:
            cmd = [sys.executable, "-u"]
            if args.module:
                cmd.append("-m")
        else:
            if args.module:
                raise ValueError("Don't use both the '--no_python' flag"
                                 " and the '--module' flag at the same time.")
        cmd.append(args.training_script)
        cmd += args.training_script_args
        cmd_args = cmd[1:]

        rdzv_configs: Dict[str, int] = {'timeout': 100}
        run_id = os.environ.get("ELASTIC_RUN_ID", ELASTIC_TRAINING_ID_DEFAULT)

        # Creating config for rendezvous class
        rdzv_parameters = RendezvousParameters(backend='c10d',
                                               endpoint=f"{args.master_addr}:{args.master_port}",
                                               run_id=run_id,
                                               min_nodes=args.min_elastic_nodes,
                                               max_nodes=args.max_elastic_nodes,
                                               **rdzv_configs)

        spec = WorkerSpec(
            role='trainer',
            local_world_size=num_local_procs,
            entrypoint=cmd[0],
            args=cmd[1:],
            rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
            max_restarts=100,
            monitor_interval=5,
            redirects=Std.from_str("0"),
            tee=Std.from_str("0"),
            master_addr=None,
            master_port=None,
        )
        agent = DSElasticAgent(spec, current_env)
        agent.run()

    sig_names = {2: "SIGINT", 15: "SIGTERM"}
    last_return_code = None

    def sigkill_handler(signum, frame):
        for process in processes:
            logger.info(f"Killing subprocess {process.pid}")
            try:
                terminate_process_tree(process.pid)
            except Exception:
                pass
        if last_return_code is not None:
            logger.error(f"{cmd} exits with return code = {last_return_code}")
            sys.exit(last_return_code)
        if signum in sig_names:
            logger.info(f"Main process received {sig_names[signum]}, exiting")
        if args.save_pid:
            if os.path.isfile(pid_file):
                os.remove(pid_file)
        sys.exit(1)

    # pass SIGINT/SIGTERM to children if the parent is being terminated
    signal.signal(signal.SIGINT, sigkill_handler)
    signal.signal(signal.SIGTERM, sigkill_handler)

    alive_processes = set(processes)
    while len(alive_processes):
        finished_processes = []
        for process in alive_processes:
            if process.poll() is None:
                # the process is still running
                continue
            else:
                if process.returncode != 0:
                    last_return_code = process.returncode  # for sigkill_handler
                    sigkill_handler(signal.SIGTERM, None)  # not coming back
                else:
                    # exited cleanly
                    logger.info(f"Process {process.pid} exits successfully.")
                    finished_processes.append(process)
        alive_processes = set(alive_processes) - set(finished_processes)

        time.sleep(1)
Example #23
def main(args=None):
    # If ``args`` not passed, defaults to ``sys.argv[:1]``
    args = parse_args(args)
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    elastic_agent = None

    if args.standalone:
        etcd_server = EtcdServer()
        etcd_server.start()
        args.rdzv_backend = "etcd"
        args.rdzv_endpoint = etcd_server.get_endpoint()
        args.rdzv_id = str(uuid.uuid4())
        log.info(f"\n**************************************\n"
                 f"Rendezvous info:\n"
                 f"--rdzv_backend={args.rdzv_backend} "
                 f"--rdzv_endpoint={args.rdzv_endpoint} "
                 f"--rdzv_id={args.rdzv_id}\n"
                 f"**************************************\n")

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        omp_num_threads = 1
        print(
            f"*****************************************\n"
            f"Setting OMP_NUM_THREADS environment variable for each process to "
            f"{omp_num_threads} by default, to avoid overloading your system. "
            f"Please further tune the variable for optimal performance in "
            f"your application as needed.\n"
            f"*****************************************")
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    with_python = not args.no_python
    cmd = []
    if with_python:
        cmd = [sys.executable, "-u"]
        if args.module:
            cmd.append("-m")
    else:
        if args.module:
            raise ValueError("Don't use both the '--no_python' flag"
                             " and the '--module' flag at the same time.")

    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

    rdzv_parameters = RendezvousParameters(
        backend=args.rdzv_backend,
        endpoint=args.rdzv_endpoint,
        run_id=args.rdzv_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        **_parse_rendezvous_config(args.rdzv_conf),
    )

    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=args.role,
            local_world_size=nproc_per_node,
            entrypoint=cmd[0],
            args=(*cmd[1:], ),
            rdzv_handler=rdzv_handler,
            max_restarts=args.max_restarts,
            monitor_interval=args.monitor_interval,
            redirects=Std.from_str(args.redirects),
            tee=Std.from_str(args.tee),
        )
        metrics.initialize_metrics()
        elastic_agent = LocalElasticAgent(spec=spec,
                                          start_method=args.start_method,
                                          log_dir=args.log_dir)
        run_result = elastic_agent.run(spec.role)
        events.record(
            elastic_agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if run_result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process
            raise ChildFailedError(
                name=args.training_script,
                failures=run_result.failures,
            )
    except ChildFailedError:
        raise
    except Exception:
        if elastic_agent:
            events.record(
                elastic_agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(args))
        raise
    finally:
        rdzv_handler.shutdown()
        if args.standalone:
            etcd_server.stop()
Example #24
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Usage:

    ::

    rdzv_params = RendezvousParameters(
                        backend="etcd",
                        endpoint="192.168.0.42:2379",
                        run_id="123",
                        min_nodes=4,
                        max_nodes=8,
                        timeout=300,
                        last_call_timeout=30,
                        etcd_prefix="custom_prefix",
                        protocol="https",
                        cacert="/etc/kubernetes/certs/ca.crt",
                        cert="/etc/kubernetes/certs/client.crt",
                        key="/etc/kubernetes/certs/client.key")
    # -- or --
    rdzv_params = RendezvousParameters(
                        backend="etcd",
                        endpoint="192.168.0.42:2379",
                        run_id="123",
                        min_nodes=4,
                        max_nodes=8)

    etcd_rdzv_handler = create_rdzv_handler(rdzv_params)


    Where:
        run_id - unique id for this training job instance,
        min_nodes - min number of workers expected to join the rendezvous,
        max_nodes - max number of workers allowed to join the rendezvous,
                        defaults to min_nodes if not specified.
        timeout - total timeout within which next_rendezvous is expected to
                      succeed; a RendezvousTimeoutError is raised otherwise;
                      Default is 600 (10 minutes).
        last_call_timeout - additional wait amount ("last call") after
                            min number of workers has been reached.
                            Defaults to 30 seconds.
        etcd_prefix - path prefix (from etcd root), inside which all
                      etcd nodes will be created.
                      Default is "/torchelastic/p2p".
        protocol - http (default) or https to access etcd.
        cacert - CA cert to access etcd, only makes sense with https.
        cert - client cert to access etcd, only makes sense with https.
        key - client key to access etcd, only makes sense with https.
    """
    client = _create_etcd_client(params)

    etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p")

    rdzv = EtcdRendezvous(
        client=client,
        prefix=etcd_prefix,
        run_id=params.run_id,
        num_min_workers=params.min_nodes,
        num_max_workers=params.max_nodes,
        timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT),
        last_call_timeout=params.get_as_int("last_call_timeout",
                                            _DEFAULT_LAST_CALL_TIMEOUT),
    )
    return EtcdRendezvousHandler(rdzv_impl=rdzv)
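A short, hedged usage sketch for the handler returned above, mirroring the (store, rank, world_size) unpacking from the static-rendezvous example earlier; the endpoint, run_id, and node counts are illustrative:

rdzv_params = RendezvousParameters(
    backend="etcd",
    endpoint="192.168.0.42:2379",
    run_id="123",
    min_nodes=4,
    max_nodes=8,
)
etcd_rdzv_handler = create_rdzv_handler(rdzv_params)

# Blocks until at least min_nodes peers have joined (or the timeout expires),
# then returns this node's rank and the agreed world size.
store, rank, world_size = etcd_rdzv_handler.next_rendezvous()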
Example #25
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generate a new one: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f"  entrypoint       : {entrypoint_name}\n"
                f"  min_nodes        : {config.min_nodes}\n"
                f"  max_nodes        : {config.max_nodes}\n"
                f"  nproc_per_node   : {config.nproc_per_node}\n"
                f"  run_id           : {config.run_id}\n"
                f"  rdzv_backend     : {config.rdzv_backend}\n"
                f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
                f"  rdzv_configs     : {config.rdzv_configs}\n"
                f"  max_restarts     : {config.max_restarts}\n"
                f"  monitor_interval : {config.monitor_interval}\n"
                f"  log_dir          : {config.log_dir}\n"
                f"  metrics_cfg      : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    agent = None
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler,
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        cfg = metrics.MetricsConfig(
            config.metrics_cfg) if config.metrics_cfg else None
        metrics.initialize_metrics(cfg)

        agent = LocalElasticAgent(spec=spec,
                                  start_method=config.start_method,
                                  log_dir=config.log_dir)

        result = agent.run()
        events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )
        else:
            return result.return_values
    except ChildFailedError:
        raise
    except Exception:
        if agent:
            events.record(agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(config))
        raise
    finally:
        rdzv_handler.shutdown()
Example #26
class CreateBackendTest(TestCase):
    def setUp(self) -> None:
        # For testing, the default parameters are for TCP. If a test uses
        # file store parameters, we set self._params to
        # self._params_filestore.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="localhost:29300",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            is_host="true",
            store_type="tCp",
            read_timeout="10",
        )

        _, tmp_path = tempfile.mkstemp()

        # Parameters for filestore testing.
        self._params_filestore = RendezvousParameters(
            backend="dummy_backend",
            endpoint=tmp_path,
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            store_type="fIlE",
        )
        self._expected_endpoint_file = tmp_path
        self._expected_temp_dir = tempfile.gettempdir()

        self._expected_endpoint_host = "localhost"
        self._expected_endpoint_port = 29300
        self._expected_store_type = TCPStore
        self._expected_read_timeout = timedelta(seconds=10)

    def tearDown(self) -> None:
        os.remove(self._expected_endpoint_file)

    def _run_test_with_store(self, store_type: str, test_to_run: Callable):
        """
        Use this function to specify the store type to use in a test. If
        not used, the test will default to TCPStore.
        """
        if store_type == "file":
            self._params = self._params_filestore
            self._expected_store_type = FileStore
            self._expected_read_timeout = timedelta(seconds=300)

        test_to_run()

    def _assert_create_backend_returns_backend(self) -> None:
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "c10d")

        self.assertIsInstance(store, self._expected_store_type)

        typecast_store = cast(self._expected_store_type, store)
        self.assertEqual(
            typecast_store.timeout,
            self._expected_read_timeout)  # type: ignore[attr-defined]
        if (self._expected_store_type == TCPStore):
            self.assertEqual(
                typecast_store.host,
                self._expected_endpoint_host)  # type: ignore[attr-defined]
            self.assertEqual(
                typecast_store.port,
                self._expected_endpoint_port)  # type: ignore[attr-defined]
        if (self._expected_store_type == FileStore):
            if self._params.endpoint:
                self.assertEqual(
                    typecast_store.path,
                    self._expected_endpoint_file)  # type: ignore[attr-defined]
            else:
                self.assertTrue(
                    typecast_store.path.startswith(
                        self._expected_temp_dir))  # type: ignore[attr-defined]

        backend.set_state(b"dummy_state")

        state = store.get("torch.rendezvous." + self._params.run_id)

        self.assertEqual(state, b64encode(b"dummy_state"))

    def test_create_backend_returns_backend(self) -> None:
        for store_type in ["tcp", "file"]:
            with self.subTest(store_type=store_type):
                self._run_test_with_store(
                    store_type, self._assert_create_backend_returns_backend)

    def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host,
            self._expected_endpoint_port,
            is_master=True)

        self._params.config["is_host"] = "false"

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified(
            self) -> None:
        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
        self, ) -> None:
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host,
            self._expected_endpoint_port,
            is_master=True)

        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_port_is_not_specified(
            self) -> None:
        self._params.endpoint = self._expected_endpoint_host

        self._expected_endpoint_port = 29400

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_file_is_not_specified(
            self) -> None:
        self._params_filestore.endpoint = ""

        self._run_test_with_store("file",
                                  self._assert_create_backend_returns_backend)

    def test_create_backend_returns_backend_if_store_type_is_not_specified(
            self) -> None:
        del self._params.config["store_type"]

        self._expected_store_type = TCPStore
        if (not self._params.get("read_timeout")):
            self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
            self) -> None:
        del self._params.config["read_timeout"]

        self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_raises_error_if_store_is_unreachable(self) -> None:
        self._params.config["is_host"] = "false"
        self._params.config["read_timeout"] = "2"

        with self.assertRaisesRegex(
                RendezvousConnectionError,
                r"^The connection to the C10d store has failed. See inner exception for details.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_endpoint_is_invalid(self) -> None:
        for is_host in [True, False]:
            with self.subTest(is_host=is_host):
                self._params.config["is_host"] = str(is_host)

                self._params.endpoint = "dummy_endpoint"

                with self.assertRaisesRegex(
                        RendezvousConnectionError,
                        r"^The connection to the C10d store has failed. See inner exception for "
                        r"details.$",
                ):
                    create_backend(self._params)

    def test_create_backend_raises_error_if_store_type_is_invalid(
            self) -> None:
        self._params.config["store_type"] = "dummy_store_type"

        with self.assertRaisesRegex(
                ValueError,
                r"^Invalid store type given. Currently only supports file and tcp.$"
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(
            self) -> None:
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                with self.assertRaisesRegex(
                        ValueError,
                        r"^The read timeout must be a positive integer.$"):
                    create_backend(self._params)

    @mock.patch("tempfile.mkstemp")
    def test_create_backend_raises_error_if_tempfile_creation_fails(
            self, tempfile_mock) -> None:
        tempfile_mock.side_effect = OSError("test error")
        # Set the endpoint to empty so it defaults to creating a temp file
        self._params_filestore.endpoint = ""
        with self.assertRaisesRegex(
                RendezvousError,
                r"The file creation for C10d store has failed. See inner exception for details."
        ):
            create_backend(self._params_filestore)

    @mock.patch(
        "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.FileStore"
    )
    def test_create_backend_raises_error_if_file_path_is_invalid(
            self, filestore_mock) -> None:
        filestore_mock.side_effect = RuntimeError("test error")
        self._params_filestore.endpoint = "bad file path"
        with self.assertRaisesRegex(
                RendezvousConnectionError,
                r"^The connection to the C10d store has failed. See inner exception for "
                r"details.$",
        ):
            create_backend(self._params_filestore)
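Outside the test harness, the (backend, store) pair returned by create_backend is used directly. A compressed, hedged restatement of what the assertions above exercise, with illustrative c10d parameters:

params = RendezvousParameters(
    backend="c10d",
    endpoint="localhost:29400",
    run_id="my_run_id",
    min_nodes=1,
    max_nodes=1,
    is_host="true",
    store_type="tcp",
    read_timeout="10",
)

backend, store = create_backend(params)

# The backend persists rendezvous state in the underlying store under a
# run-id-scoped key, as the tests above verify.
backend.set_state(b"some_state")
state = store.get("torch.rendezvous." + params.run_id)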