def setUp(self) -> None:
    """Create default TCP-store and file-store rendezvous parameters for tests."""
    # For testing, the default parameters used are for tcp. If a test
    # uses parameters for file store, we set the self._params to
    # self._params_filestore.
    self._params = RendezvousParameters(
        backend="dummy_backend",
        endpoint="localhost:29300",
        run_id="dummy_run_id",
        min_nodes=1,
        max_nodes=1,
        is_host="true",
        store_type="tCp",
        read_timeout="10",
    )

    # NOTE: tempfile.mkstemp() returns an OPEN file descriptor alongside the
    # path; the original code discarded it, leaking one fd per test.  Using
    # NamedTemporaryFile(delete=False) creates the same on-disk file but the
    # context manager closes the handle immediately.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp_path = tmp.name

    # Parameters for filestore testing.
    self._params_filestore = RendezvousParameters(
        backend="dummy_backend",
        endpoint=tmp_path,
        run_id="dummy_run_id",
        min_nodes=1,
        max_nodes=1,
        store_type="fIlE",
    )
    self._expected_endpoint_file = tmp_path
    self._expected_temp_dir = tempfile.gettempdir()

    self._expected_endpoint_host = "localhost"
    self._expected_endpoint_port = 29300
    self._expected_store_type = TCPStore
    self._expected_read_timeout = timedelta(seconds=10)
def test_get_or_default(self):
    """get() returns a stored extra-config value, else the supplied default."""
    kwargs = {
        "backend": "foobar",
        "endpoint": "localhost",
        "run_id": "1234",
        "min_nodes": 1,
        "max_nodes": 1,
        "timeout1": 10,
    }
    params = RendezvousParameters(**kwargs)

    # "timeout1" was supplied, so the fallback of 20 is ignored.
    self.assertEqual(10, params.get("timeout1", 20))
    # "timeout2" was never supplied, so the fallback of 60 is returned.
    self.assertEqual(60, params.get("timeout2", 60))
def _get_worker_spec(
    self,
    max_restarts=1,
    monitor_interval=1.0,
    role="test_trainer",
    local_world_size=8,
):
    """Build a WorkerSpec backed by an etcd rendezvous on the test server."""
    params = RendezvousParameters(
        backend="etcd",
        endpoint=self._etcd_server.get_endpoint(),
        run_id=str(uuid.uuid4().int),
        min_nodes=1,
        max_nodes=1,
    )
    handler = rdzv_registry.get_rendezvous_handler(params)
    return WorkerSpec(
        role=role,
        local_world_size=local_world_size,
        fn=do_nothing,
        args=(),
        rdzv_handler=handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
    )
def get_worker_spec(
    self,
    node_config: Conf,
    min_nodes=1,
    max_nodes=1,
    max_restarts=0,
    monitor_interval=0.01,
):
    """Build a WorkerSpec for ``node_config`` using the fixture's etcd server."""
    params = RendezvousParameters(
        backend="etcd",
        endpoint=self._etcd_server.get_endpoint(),
        run_id=self._run_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
    )
    handler = rdzv_registry.get_rendezvous_handler(params)
    return WorkerSpec(
        role=node_config.role,
        local_world_size=node_config.local_world_size,
        entrypoint=node_config.entrypoint,
        args=node_config.args,
        rdzv_handler=handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
        redirects=node_config.redirects,
        tee=node_config.tee,
    )
def _get_worker_spec(
    self,
    max_restarts=1,
    monitor_interval=1.0,
    role="test_trainer",
    local_world_size=8,
):
    """Build a WorkerSpec using a static rendezvous on a free local port."""
    endpoint = f"127.0.0.1:{get_free_port()}"
    params = RendezvousParameters(
        backend="static",
        endpoint=endpoint,
        run_id=str(uuid.uuid4().int),
        min_nodes=1,
        max_nodes=1,
        rank=0,
    )
    handler = rdzv_registry.get_rendezvous_handler(params)
    return WorkerSpec(
        role=role,
        local_world_size=local_world_size,
        fn=do_nothing,
        args=(),
        rdzv_handler=handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
    )
def test_static_rdzv_multiple_calls(self):
    """next_rendezvous() on a static handler yields the same result on repeat calls."""
    sock = get_socket_with_port()
    with closing(sock):
        master_port = sock.getsockname()[1]
    master_addr = "localhost"

    rdzv_params = RendezvousParameters(
        backend="static",
        endpoint=f"{master_addr}:{master_port}",
        run_id="test_id",
        min_nodes=1,
        max_nodes=1,
        rank=0,
    )
    rdzv_handler = create_rdzv_handler(rdzv_params)

    # Call rendezvous two times; both must succeed with identical results.
    for _ in range(2):
        store, rank, world_size = rdzv_handler.next_rendezvous()
        self.assertIsNotNone(store)
        self.assertEqual(0, rank)
        self.assertEqual(1, world_size)
def get_worker_spec(
    self,
    node_config: Conf,
    min_nodes=1,
    max_nodes=1,
    max_restarts=0,
    monitor_interval=0.01,
    master_addr_override: Optional[str] = None,
    master_port_override: Optional[int] = None,
    is_host=True,
):
    """Build a WorkerSpec for ``node_config``, optionally overriding master addr/port."""
    params = RendezvousParameters(
        backend=self._backend,
        endpoint=self._endpoint,
        run_id=self._run_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        is_host=is_host,
    )
    handler = rdzv_registry.get_rendezvous_handler(params)
    return WorkerSpec(
        role=node_config.role,
        local_world_size=node_config.local_world_size,
        entrypoint=node_config.entrypoint,
        args=node_config.args,
        rdzv_handler=handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
        redirects=node_config.redirects,
        tee=node_config.tee,
        master_addr=master_addr_override,
        master_port=master_port_override,
    )
def _create_params(self) -> RendezvousParameters:
    """Assemble RendezvousParameters from the fixture's configured attributes."""
    extra = dict(self._kwargs)
    return RendezvousParameters(
        backend=self._backend,
        endpoint=self._endpoint,
        run_id=self._run_id,
        min_nodes=self._min_nodes,
        max_nodes=self._max_nodes,
        **extra,
    )
def test_no_factory_method_found(self):
    """create_handler raises ValueError when no factory is registered for a backend."""
    factory = RendezvousHandlerFactory()
    params = RendezvousParameters(
        backend="mock",
        endpoint="",
        run_id="foobar",
        min_nodes=1,
        max_nodes=2,
    )
    with self.assertRaises(ValueError):
        factory.create_handler(params)
def setUp(self) -> None:
    """Fixture: dummy rendezvous parameters and a fresh handler registry."""
    kwargs = {
        "backend": "dummy_backend",
        "endpoint": "dummy_endpoint",
        "run_id": "dummy_run_id",
        "min_nodes": 1,
        "max_nodes": 1,
    }
    self._params = RendezvousParameters(**kwargs)
    self._registry = RendezvousHandlerRegistry()
def test_ipv6_addr_localhost(self):
    """The static handler rejects an IPv6 loopback endpoint with ValueError."""
    endpoint = "[::1]:90"
    params = RendezvousParameters(
        backend="static",
        endpoint=endpoint,
        run_id="test_id",
        min_nodes=1,
        max_nodes=1,
    )
    with self.assertRaises(ValueError):
        create_rdzv_handler(params)
def test_empty_endpoint(self):
    """The static handler rejects an empty endpoint with ValueError."""
    params = RendezvousParameters(
        backend="static",
        endpoint="",
        run_id="test_id",
        min_nodes=1,
        max_nodes=1,
    )
    with self.assertRaises(ValueError):
        create_rdzv_handler(params)
def test_ipv6_addr(self):
    """The static handler rejects a full IPv6 address endpoint with ValueError."""
    endpoint = "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:90"
    params = RendezvousParameters(
        backend="static",
        endpoint=endpoint,
        run_id="test_id",
        min_nodes=1,
        max_nodes=1,
    )
    with self.assertRaises(ValueError):
        create_rdzv_handler(params)
def test_create_handler(self):
    """A registered factory produces a handler of the expected concrete type."""
    factory = RendezvousHandlerFactory()
    factory.register("mock", create_mock_rdzv_handler)

    params = RendezvousParameters(
        backend="mock",
        endpoint="",
        run_id="foobar",
        min_nodes=1,
        max_nodes=2,
    )
    handler = factory.create_handler(params)
    self.assertTrue(isinstance(handler, MockRendezvousHandler))
def setUp(self) -> None:
    """Fixture: parameters pointing at the test server, with mixed-case extras."""
    config = {
        "backend": "dummy_backend",
        "endpoint": self._server.get_endpoint(),
        "run_id": "dummy_run_id",
        "min_nodes": 1,
        "max_nodes": 1,
        # Mixed-case on purpose: value handling should be case-insensitive.
        "protocol": "hTTp",
        "read_timeout": "10",
    }
    self._params = RendezvousParameters(**config)
    self._expected_read_timeout = 10
def test_get_backend(self):
    """get_backend() reports "static" for a static rendezvous handler."""
    config = {
        "backend": "static",
        "endpoint": "localhost:123",
        "run_id": "test",
        "min_nodes": 1,
        "max_nodes": 1,
        "timeout": 60,
        "rank": 0,
    }
    handler = create_rdzv_handler(RendezvousParameters(**config))
    self.assertEqual("static", handler.get_backend())
def test_etcd_rdzv_basic_params(self):
    """
    Check that we can create the handler with a minimum set of params
    """
    params = RendezvousParameters(
        backend="etcd",
        endpoint=f"{self._etcd_server.get_endpoint()}",
        run_id=f"{uuid.uuid4()}",
        min_nodes=1,
        max_nodes=1,
    )
    handler = create_rdzv_handler(params)
    self.assertIsNotNone(handler)
def test_get_backend(self):
    """get_backend() reports "etcd" for an etcd-backed handler."""
    run_id = str(uuid.uuid4())
    config = {
        "backend": "etcd",
        "endpoint": f"{self._etcd_server.get_endpoint()}",
        "run_id": run_id,
        "min_nodes": 1,
        "max_nodes": 1,
        "timeout": 60,
        "last_call_timeout": 30,
        "protocol": "http",
    }
    handler = create_rdzv_handler(RendezvousParameters(**config))
    self.assertEqual("etcd", handler.get_backend())
def setUp(self) -> None:
    """Fixture: TCP-store parameters plus the values tests expect to see parsed."""
    config = {
        "backend": "dummy_backend",
        "endpoint": "localhost:29400",
        "run_id": "dummy_run_id",
        "min_nodes": 1,
        "max_nodes": 1,
        "is_host": "true",
        # Mixed-case on purpose: store_type handling should be case-insensitive.
        "store_type": "tCp",
        "read_timeout": "10",
    }
    self._params = RendezvousParameters(**config)

    # Values the tests expect the backend to derive from the config above.
    self._expected_endpoint_host = "localhost"
    self._expected_endpoint_port = 29400
    self._expected_store_type = TCPStore
    self._expected_read_timeout = timedelta(seconds=10)
def setUp(self) -> None:
    """Fixture: dummy store/backend plus parameters with explicit string timeouts."""
    self._store = DummyStore()
    self._backend = DummyRendezvousBackend()

    self._params = RendezvousParameters(
        backend=self._backend.name,
        endpoint="dummy_endpoint",
        run_id="dummy_run_id",
        min_nodes=3,
        max_nodes=6,
        join_timeout="50",
        last_call_timeout="60",
        close_timeout="70",
    )

    # The string timeouts above should be parsed into these timedeltas.
    join, last_call, close = (timedelta(seconds=s) for s in (50, 60, 70))
    self._expected_timeout = RendezvousTimeout(join, last_call, close)
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run ``entrypoint`` under a LocalElasticAgent configured from ``config``.

    Returns a dict mapping each local worker rank to its return value.
    Raises ChildFailedError if any worker failed; re-raises agent errors.
    """
    # Ensure every launch has a run_id; generate one if the caller did not.
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(
            f"config has no run_id, generated a random run_id: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f" entrypoint : {entrypoint_name}\n"
                f" min_nodes : {config.min_nodes}\n"
                f" max_nodes : {config.max_nodes}\n"
                f" nproc_per_node : {config.nproc_per_node}\n"
                f" run_id : {config.run_id}\n"
                f" rdzv_backend : {config.rdzv_backend}\n"
                f" rdzv_endpoint : {config.rdzv_endpoint}\n"
                f" rdzv_configs : {config.rdzv_configs}\n"
                f" max_restarts : {config.max_restarts}\n"
                f" monitor_interval : {config.monitor_interval}\n"
                f" log_dir : {config.log_dir}\n"
                f" metrics_cfg : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    # Resolve the master address/port from the rendezvous parameters and
    # bake them into the worker spec.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
    )

    agent = LocalElasticAgent(spec=spec,
                              start_method=config.start_method,
                              log_dir=config.log_dir)

    # The rendezvous is shut down in the finally-block unless the agent died
    # from a signal (see SignalException handler below).
    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
def main():
    """Launcher entry point: spawn one training process per local GPU (or an
    elastic agent), wire up distributed env vars, and babysit the children.

    Blocks until all child processes exit; exits the interpreter with the
    first non-zero child return code via ``sigkill_handler``.
    """
    args = parse_args()
    current_env = os.environ.copy()

    # Log any NCCL-related env vars for debugging distributed comms.
    for k in current_env.keys():
        if "NCCL" in k:
            logger.info(f"{args.node_rank} {k}={current_env[k]}")

    if args.world_info == "None":
        raise ValueError("world_info can not be None")
    # world_info arrives base64-encoded JSON: {node_name: [gpu_ids, ...]}.
    world_info = base64.urlsafe_b64decode(args.world_info)
    world_info = json.loads(world_info)

    logger.info(f"WORLD INFO DICT: {world_info}")
    node_list = list(world_info.keys())
    args.nnodes = len(node_list)
    local_node = node_list[args.node_rank]
    local_gpu_ids = world_info[local_node]
    num_local_procs = len(local_gpu_ids)
    logger.info(
        f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
    )

    # Assign consecutive global ranks to each node's GPUs, in node-list order.
    global_rank_mapping = defaultdict(list)
    curr_global_rank = 0
    dist_world_size = 0
    for node_id in node_list:
        gids = world_info[node_id]
        dist_world_size += len(gids)
        for gid in gids:
            global_rank_mapping[node_id].append(curr_global_rank)
            curr_global_rank += 1
    logger.info(f"global_rank_mapping={global_rank_mapping}")
    logger.info(f"dist_world_size={dist_world_size}")
    current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids))
    logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}")

    # set PyTorch distributed related environmental variables
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    current_env["CROSS_RANK"] = str(args.node_rank)
    current_env["CROSS_SIZE"] = str(args.nnodes)
    current_env["LOCAL_SIZE"] = str(num_local_procs)

    if args.save_pid:
        print(f"launcher pid: {os.getpid()}")

    # Optionally persist the launcher pid so an external tool can find it;
    # pid_file stays None otherwise and is checked again in sigkill_handler.
    pid_file = None
    if args.save_pid:
        launcher_pid = os.getpid()
        pid_file = os.path.join(PID_FILE_BASEPATH, f"{args.save_pid}.deepspeed")
        assert not os.path.isfile(pid_file), "pid file exists but shouldn't"
        with open(pid_file, 'w') as fd:
            fd.write(f"{launcher_pid}")

    if not is_torch_elastic_compatible():
        if args.enable_elastic_training:
            logger.info(f"Disabling elastic training support as \
PyTorch version should be greater than 1.11.x")
            args.enable_elastic_training = False

    # Import selected env vars exported by the DLTS pod environment file.
    if os.path.exists(DLTS_POD_ENV_PATH):
        with open(DLTS_POD_ENV_PATH) as file:
            lines = file.readlines()
            lines = [line.rstrip() for line in lines]
            for line in lines:
                if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
                        'export FC_TASK_INDEX'):
                    key_val = line.split()[1]
                    key, val = key_val.split('=')
                    current_env[key] = val

    processes = []
    cmd = []

    if not args.enable_elastic_training:
        # Non-elastic path: fork one subprocess per local GPU.
        for local_rank in range(0, num_local_procs):
            # each process's rank
            dist_rank = global_rank_mapping[local_node][local_rank]
            current_env["RANK"] = str(dist_rank)
            current_env["LOCAL_RANK"] = str(local_rank)

            # spawn the processes
            cmd = []
            if not args.no_python:
                cmd = [sys.executable, "-u"]
                if args.module:
                    cmd.append("-m")
            else:
                if args.module:
                    raise ValueError("Don't use both the '--no_python' flag"
                                     " and the '--module' flag at the same time.")
            cmd.append(args.training_script)
            # A user may not want to pass local_rank as a keyword arg so we make this optional.
            if not args.no_local_rank:
                cmd.append(f"--local_rank={local_rank}")
            cmd += args.training_script_args

            process = subprocess.Popen(cmd, env=current_env)
            processes.append(process)
    else:
        # Elastic path: run the workers through a DeepSpeed elastic agent
        # instead of raw subprocesses. Imports are local so the non-elastic
        # path never requires torch.distributed.elastic.
        from ..elasticity import DSElasticAgent
        from torch.distributed.elastic.rendezvous import RendezvousParameters
        from torch.distributed.elastic.agent.server.api import WorkerSpec
        import torch.distributed.elastic.rendezvous.registry as rdzv_registry
        from torch.distributed.elastic.multiprocessing import Std

        if args.min_elastic_nodes == -1:
            args.min_elastic_nodes = 1
        if args.max_elastic_nodes == -1:
            args.max_elastic_nodes = args.nnodes
        assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0 , "Max and Min nodes should be positive"

        current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)

        # Get config and arguments
        cmd = []
        if not args.no_python:
            cmd = [sys.executable, "-u"]
            if args.module:
                cmd.append("-m")
        else:
            if args.module:
                raise ValueError("Don't use both the '--no_python' flag"
                                 " and the '--module' flag at the same time.")
        cmd.append(args.training_script)
        cmd += args.training_script_args
        cmd_args = cmd[1:]

        rdzv_configs: Dict[str, str] = {'timeout': 100}
        run_id = os.environ.get("ELASTIC_RUN_ID", ELASTIC_TRAINING_ID_DEFAULT)

        # Creating config for rendezvous class
        rdzv_parameters = RendezvousParameters(backend='c10d',
                                               endpoint=args.master_addr + ":" +
                                               str(args.master_port),
                                               run_id=run_id,
                                               min_nodes=args.min_elastic_nodes,
                                               max_nodes=args.max_elastic_nodes,
                                               **rdzv_configs)

        spec = WorkerSpec(
            role='trainer',
            local_world_size=num_local_procs,
            entrypoint=cmd[0],
            args=cmd[1:],
            rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
            max_restarts=100,
            monitor_interval=5,
            redirects=Std.from_str("0"),
            tee=Std.from_str("0"),
            master_addr=None,
            master_port=None,
        )
        agent = DSElasticAgent(spec, current_env)
        agent.run()

    sig_names = {2: "SIGINT", 15: "SIGTERM"}
    last_return_code = None

    def sigkill_handler(signum, frame):
        # Closure over processes/cmd/last_return_code/pid_file above: kill the
        # whole child tree, then exit with the recorded failure code (if any).
        for process in processes:
            logger.info(f"Killing subprocess {process.pid}")
            try:
                terminate_process_tree(process.pid)
            except Exception:
                # Best-effort kill; the child may already be gone.
                pass
        if last_return_code is not None:
            logger.error(f"{cmd} exits with return code = {last_return_code}")
            sys.exit(last_return_code)
        if signum in sig_names:
            logger.info(f"Main process received {sig_names[signum]}, exiting")
        if args.save_pid:
            if os.path.isfile(pid_file):
                os.remove(pid_file)

        sys.exit(1)

    # pass SIGINT/SIGTERM to children if the parent is being terminated
    signal.signal(signal.SIGINT, sigkill_handler)
    signal.signal(signal.SIGTERM, sigkill_handler)

    # Poll children once a second; the first non-zero exit kills everything.
    alive_processes = set(processes)
    while len(alive_processes):
        finished_processes = []
        for process in alive_processes:
            if process.poll() is None:
                # the process is still running
                continue
            else:
                if process.returncode != 0:
                    last_return_code = process.returncode  # for sigkill_handler
                    sigkill_handler(signal.SIGTERM, None)  # not coming back
                else:
                    # exited cleanly
                    logger.info(f"Process {process.pid} exits successfully.")
                    finished_processes.append(process)
        alive_processes = set(alive_processes) - set(finished_processes)

        time.sleep(1)
def main(args=None):
    """Torch elastic launcher entry point: parse args, build a worker spec,
    and run it under a LocalElasticAgent.

    Raises ChildFailedError if any worker failed; always shuts the rendezvous
    down (and the standalone etcd server, if one was started) on exit.
    """
    # If ``args`` not passed, defaults to ``sys.argv[:1]``
    args = parse_args(args)

    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    elastic_agent = None

    # Standalone mode spins up a private etcd server and rewrites the
    # rendezvous args to point at it.
    if args.standalone:
        etcd_server = EtcdServer()
        etcd_server.start()
        args.rdzv_backend = "etcd"
        args.rdzv_endpoint = etcd_server.get_endpoint()
        args.rdzv_id = str(uuid.uuid4())
        log.info(f"\n**************************************\n"
                 f"Rendezvous info:\n"
                 f"--rdzv_backend={args.rdzv_backend} "
                 f"--rdzv_endpoint={args.rdzv_endpoint} "
                 f"--rdzv_id={args.rdzv_id}\n"
                 f"**************************************\n")

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        omp_num_threads = 1
        print(f"*****************************************\n"
              f"Setting OMP_NUM_THREADS environment variable for each process to be "
              f"{omp_num_threads} in default, to avoid your system being overloaded, "
              f"please further tune the variable for optimal performance in "
              f"your application as needed. \n"
              f"*****************************************")
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    # Build the worker command line: [python -u [-m]] script args...
    with_python = not args.no_python
    cmd = []
    if with_python:
        cmd = [sys.executable, "-u"]
        if args.module:
            cmd.append("-m")
    else:
        if args.module:
            raise ValueError("Don't use both the '--no_python' flag"
                             " and the '--module' flag at the same time.")
    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

    rdzv_parameters = RendezvousParameters(
        backend=args.rdzv_backend,
        endpoint=args.rdzv_endpoint,
        run_id=args.rdzv_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        **_parse_rendezvous_config(args.rdzv_conf),
    )

    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=args.role,
            local_world_size=nproc_per_node,
            entrypoint=cmd[0],
            args=(*cmd[1:], ),
            rdzv_handler=rdzv_handler,
            max_restarts=args.max_restarts,
            monitor_interval=args.monitor_interval,
            redirects=Std.from_str(args.redirects),
            tee=Std.from_str(args.tee),
        )
        metrics.initialize_metrics()
        elastic_agent = LocalElasticAgent(spec=spec,
                                          start_method=args.start_method,
                                          log_dir=args.log_dir)
        run_result = elastic_agent.run(spec.role)
        # Records that the agent run succeeded, NOT that the workers did.
        events.record(
            elastic_agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if run_result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process
            raise ChildFailedError(
                name=args.training_script,
                failures=run_result.failures,
            )
    except ChildFailedError:
        raise
    except Exception:
        # elastic_agent is None if the failure happened before it was built.
        if elastic_agent:
            events.record(
                elastic_agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(args))
        raise
    finally:
        rdzv_handler.shutdown()
        if args.standalone:
            etcd_server.stop()
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create an etcd-backed RendezvousHandler from ``params``.

    Usage:

    ::

     rdzv_params = RendezvousParameters(
                         backend="etcd",
                         endpoint="192.168.0.42:2379",
                         run_id="123",
                         min_nodes=4,
                         max_nodes=8,
                         timeout=300,
                         last_call_timeout=30,
                         etcd_prefix="custom_prefix",
                         protocol="https",
                         cacert="/etc/kubernetes/certs/ca.crt",
                         cert="/etc/kubernetes/certs/client.crt",
                         key="/etc/kubernetes/certs/client.key")
     # -- or --
     rdzv_params = RendezvousParameters(
                         backend="etcd",
                         endpoint="192.168.0.42:2379",
                         run_id="123",
                         min_nodes=4,
                         max_nodes=8)

     etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params)

    Where:
        run_id - unique id for this training job instance,
        min_nodes - min number of workers expected to join the rendezvous,
        max_nodes - max number of workers allowed to join the rendezvous,
                    defaults to min_workers is not specified.
        timeout - total timeout within which next_rendezvous is expected to
                  succeed; a RendezvousTimeoutError is raised otherwise;
                  Defaults is 600 (10 minutes).
        last_call_timeout - additional wait amount ("last call") after min
                            number of workers has been reached.
                            Defaults to 30 seconds.
        etcd_prefix - path prefix (from etcd root), inside which all
                      etcd nodes will be created.
                      Default is "/torchelastic/p2p".
        protocol - http (default) or https to access etcd.
        cacert - CA cert to access etcd, only makes sense with https.
        cert - client cert to access etcd, only makes sense with https.
        key - client key to access etcd, only makes sense with https.
    """
    etcd_client = _create_etcd_client(params)

    prefix = params.get("etcd_prefix", "/torchelastic/p2p")
    total_timeout = params.get_as_int("timeout", _DEFAULT_TIMEOUT)
    last_call = params.get_as_int("last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT)

    rendezvous = EtcdRendezvous(
        client=etcd_client,
        prefix=prefix,
        run_id=params.run_id,
        num_min_workers=params.min_nodes,
        num_max_workers=params.max_nodes,
        timeout=total_timeout,
        last_call_timeout=last_call,
    )
    return EtcdRendezvousHandler(rdzv_impl=rendezvous)
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run ``entrypoint`` under a LocalElasticAgent configured from ``config``.

    Returns a dict mapping each local worker rank to its return value.
    Raises ChildFailedError if any worker failed; always shuts the
    rendezvous down on exit.
    """
    # Ensure every launch has a run_id; generate one if the caller did not.
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generate a new one: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f" entrypoint : {entrypoint_name}\n"
                f" min_nodes : {config.min_nodes}\n"
                f" max_nodes : {config.max_nodes}\n"
                f" nproc_per_node : {config.nproc_per_node}\n"
                f" run_id : {config.run_id}\n"
                f" rdzv_backend : {config.rdzv_backend}\n"
                f" rdzv_endpoint : {config.rdzv_endpoint}\n"
                f" rdzv_configs : {config.rdzv_configs}\n"
                f" max_restarts : {config.max_restarts}\n"
                f" monitor_interval : {config.monitor_interval}\n"
                f" log_dir : {config.log_dir}\n"
                f" metrics_cfg : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    # agent stays None until constructed; the except-clause below uses that
    # to decide which failure event to record.
    agent = None
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler,
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        cfg = metrics.MetricsConfig(
            config.metrics_cfg) if config.metrics_cfg else None
        metrics.initialize_metrics(cfg)

        agent = LocalElasticAgent(spec=spec,
                                  start_method=config.start_method,
                                  log_dir=config.log_dir)

        result = agent.run()
        # Records that the agent run succeeded, NOT that the workers did.
        events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )
        else:
            return result.return_values
    except ChildFailedError:
        raise
    except Exception:
        if agent:
            events.record(agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(config))
        raise
    finally:
        rdzv_handler.shutdown()
class CreateBackendTest(TestCase):
    """Tests for ``create_backend`` with both TCPStore and FileStore configs."""

    def setUp(self) -> None:
        # For testing, the default parameters used are for tcp. If a test
        # uses parameters for file store, we set the self._params to
        # self._params_filestore.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="localhost:29300",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            is_host="true",
            store_type="tCp",
            read_timeout="10",
        )

        # FIX: mkstemp() returns an OPEN file descriptor alongside the path;
        # close it so the descriptor does not leak (tearDown only removes the
        # file, it never closed the fd).
        fd, tmp_path = tempfile.mkstemp()
        os.close(fd)

        # Parameters for filestore testing.
        self._params_filestore = RendezvousParameters(
            backend="dummy_backend",
            endpoint=tmp_path,
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            store_type="fIlE",
        )
        self._expected_endpoint_file = tmp_path
        self._expected_temp_dir = tempfile.gettempdir()

        self._expected_endpoint_host = "localhost"
        self._expected_endpoint_port = 29300
        self._expected_store_type = TCPStore
        self._expected_read_timeout = timedelta(seconds=10)

    def tearDown(self) -> None:
        os.remove(self._expected_endpoint_file)

    def _run_test_with_store(self, store_type: str, test_to_run: Callable):
        """
        Use this function to specify the store type to use in a test. If
        not used, the test will default to TCPStore.
        """
        if store_type == "file":
            self._params = self._params_filestore
            self._expected_store_type = FileStore
            self._expected_read_timeout = timedelta(seconds=300)

        test_to_run()

    def _assert_create_backend_returns_backend(self) -> None:
        # Shared assertion helper: create_backend must return a c10d backend
        # plus a store matching the expectations set up by the fixture.
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "c10d")

        self.assertIsInstance(store, self._expected_store_type)

        typecast_store = cast(self._expected_store_type, store)
        self.assertEqual(
            typecast_store.timeout,
            self._expected_read_timeout)  # type: ignore[attr-defined]
        if (self._expected_store_type == TCPStore):
            self.assertEqual(
                typecast_store.host,
                self._expected_endpoint_host)  # type: ignore[attr-defined]
            self.assertEqual(
                typecast_store.port,
                self._expected_endpoint_port)  # type: ignore[attr-defined]
        if (self._expected_store_type == FileStore):
            if self._params.endpoint:
                self.assertEqual(
                    typecast_store.path,
                    self._expected_endpoint_file)  # type: ignore[attr-defined]
            else:
                # No endpoint given: the backend creates its own temp file.
                self.assertTrue(
                    typecast_store.path.startswith(
                        self._expected_temp_dir))  # type: ignore[attr-defined]

        backend.set_state(b"dummy_state")

        state = store.get("torch.rendezvous." + self._params.run_id)

        self.assertEqual(state, b64encode(b"dummy_state"))

    def test_create_backend_returns_backend(self) -> None:
        for store_type in ["tcp", "file"]:
            with self.subTest(store_type=store_type):
                self._run_test_with_store(
                    store_type, self._assert_create_backend_returns_backend)

    def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
        # Pre-create the master store so the non-host side has something to
        # connect to.
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host, self._expected_endpoint_port,
            is_master=True)

        self._params.config["is_host"] = "false"

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified(
            self) -> None:
        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
        self,
    ) -> None:
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host, self._expected_endpoint_port,
            is_master=True)

        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_port_is_not_specified(
            self) -> None:
        self._params.endpoint = self._expected_endpoint_host

        # With no port in the endpoint, the backend falls back to 29400
        # (presumably its built-in default port).
        self._expected_endpoint_port = 29400

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_file_is_not_specified(
            self) -> None:
        self._params_filestore.endpoint = ""

        self._run_test_with_store("file",
                                  self._assert_create_backend_returns_backend)

    def test_create_backend_returns_backend_if_store_type_is_not_specified(
            self) -> None:
        del self._params.config["store_type"]

        self._expected_store_type = TCPStore
        if (not self._params.get("read_timeout")):
            self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
            self) -> None:
        del self._params.config["read_timeout"]

        self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_raises_error_if_store_is_unreachable(self) -> None:
        self._params.config["is_host"] = "false"
        # Short timeout so the unreachable-store failure surfaces quickly.
        self._params.config["read_timeout"] = "2"

        with self.assertRaisesRegex(
                RendezvousConnectionError,
                r"^The connection to the C10d store has failed. See inner exception for details.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_endpoint_is_invalid(self) -> None:
        for is_host in [True, False]:
            with self.subTest(is_host=is_host):
                self._params.config["is_host"] = str(is_host)

                self._params.endpoint = "dummy_endpoint"

                with self.assertRaisesRegex(
                        RendezvousConnectionError,
                        r"^The connection to the C10d store has failed. See inner exception for "
                        r"details.$",
                ):
                    create_backend(self._params)

    def test_create_backend_raises_error_if_store_type_is_invalid(
            self) -> None:
        self._params.config["store_type"] = "dummy_store_type"

        with self.assertRaisesRegex(
                ValueError,
                r"^Invalid store type given. Currently only supports file and tcp.$"
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(
            self) -> None:
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                with self.assertRaisesRegex(
                        ValueError,
                        r"^The read timeout must be a positive integer.$"):
                    create_backend(self._params)

    @mock.patch("tempfile.mkstemp")
    def test_create_backend_raises_error_if_tempfile_creation_fails(
            self, tempfile_mock) -> None:
        tempfile_mock.side_effect = OSError("test error")
        # Set the endpoint to empty so it defaults to creating a temp file
        self._params_filestore.endpoint = ""
        with self.assertRaisesRegex(
                RendezvousError,
                r"The file creation for C10d store has failed. See inner exception for details."
        ):
            create_backend(self._params_filestore)

    @mock.patch(
        "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.FileStore"
    )
    def test_create_backend_raises_error_if_file_path_is_invalid(
            self, filestore_mock) -> None:
        filestore_mock.side_effect = RuntimeError("test error")
        self._params_filestore.endpoint = "bad file path"
        with self.assertRaisesRegex(
                RendezvousConnectionError,
                r"^The connection to the C10d store has failed. See inner exception for "
                r"details.$",
        ):
            create_backend(self._params_filestore)