def get_worker_spec(
    self,
    node_config: Conf,
    min_nodes=1,
    max_nodes=1,
    max_restarts=0,
    monitor_interval=0.01,
    master_addr_override: Optional[str] = None,
    master_port_override: Optional[int] = None,
    is_host=True,
):
    rdzv_params = RendezvousParameters(
        backend=self._backend,
        endpoint=self._endpoint,
        run_id=self._run_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        is_host=is_host,
    )
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
    return WorkerSpec(
        role=node_config.role,
        local_world_size=node_config.local_world_size,
        entrypoint=node_config.entrypoint,
        args=node_config.args,
        rdzv_handler=rdzv_handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
        redirects=node_config.redirects,
        tee=node_config.tee,
        master_addr=master_addr_override,
        master_port=master_port_override,
    )
def get_worker_spec(
    self,
    node_config: Conf,
    min_nodes=1,
    max_nodes=1,
    max_restarts=0,
    monitor_interval=0.01,
):
    rdzv_params = RendezvousParameters(
        backend="etcd",
        endpoint=self._etcd_server.get_endpoint(),
        run_id=self._run_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
    )
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
    return WorkerSpec(
        role=node_config.role,
        local_world_size=node_config.local_world_size,
        entrypoint=node_config.entrypoint,
        args=node_config.args,
        rdzv_handler=rdzv_handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
        redirects=node_config.redirects,
        tee=node_config.tee,
    )
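# A minimal sketch (not from the source) of how a spec built by the helpers above is
# typically consumed in a test. `Conf` mirrors the test suite's node config dataclass and
# `_dist_sum` stands in for any picklable worker entrypoint; both names are assumptions.
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent

def run_agent_smoke_test(self):
    node_config = Conf(entrypoint=_dist_sum, local_world_size=2, args=())  # hypothetical config
    spec = self.get_worker_spec(node_config, max_restarts=0)
    # LocalElasticAgent performs the rendezvous and spawns local_world_size worker processes.
    agent = LocalElasticAgent(spec, start_method="spawn")
    result = agent.run()
    self.assertFalse(result.is_failed())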
def test_worker_group_constructor(self):
    spec = WorkerSpec(
        role="test_trainer",
        local_world_size=4,
        fn=do_nothing,
        args=(),
        rdzv_handler=None,
        max_restarts=50,
        monitor_interval=1,
    )
    worker_group = WorkerGroup(spec)

    self.assertEqual(WorkerState.INIT, worker_group.state)

    workers = worker_group.workers
    self.assertEqual(4, len(workers))

    # validate full, consecutive local ranks
    self.assertSetEqual(set(range(4)), {w.local_rank for w in workers})

    # global_rank and world_size are assigned after rendezvous;
    # id is assigned after the agent starts the worker.
    # validate that they are still unset
    for w in workers:
        self.assertEqual(-1, w.global_rank)
        self.assertEqual(-1, w.world_size)
        self.assertEqual(None, w.id)

    # rank and store are assigned after rdzv; validate that they are None
    self.assertIsNone(worker_group.group_rank)
    self.assertIsNone(worker_group.store)
def _get_worker_spec(
    self,
    max_restarts=1,
    monitor_interval=1.0,
    role="test_trainer",
    local_world_size=8,
):
    run_id = str(uuid.uuid4().int)
    endpoint = self._etcd_server.get_endpoint()
    rdzv_params = RendezvousParameters(
        backend="etcd",
        endpoint=endpoint,
        run_id=run_id,
        min_nodes=1,
        max_nodes=1,
    )
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
    spec = WorkerSpec(
        role=role,
        local_world_size=local_world_size,
        fn=do_nothing,
        args=(),
        rdzv_handler=rdzv_handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
    )
    return spec
def _get_worker_spec(
    self,
    max_restarts=1,
    monitor_interval=1.0,
    role="test_trainer",
    local_world_size=8,
):
    run_id = str(uuid.uuid4().int)
    port = get_free_port()
    endpoint = f"127.0.0.1:{port}"
    # unlike the etcd variant above, the static backend needs an explicit rank
    # and a fixed master endpoint
    rdzv_params = RendezvousParameters(
        backend="static",
        endpoint=endpoint,
        run_id=run_id,
        min_nodes=1,
        max_nodes=1,
        rank=0,
    )
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
    spec = WorkerSpec(
        role=role,
        local_world_size=local_world_size,
        fn=do_nothing,
        args=(),
        rdzv_handler=rdzv_handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
    )
    return spec
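# `get_free_port` is assumed to be a small test utility that is not shown in this
# section; a common way to implement it (a sketch, not necessarily the project's
# version) is to bind an ephemeral socket and return whatever port the OS assigns.
import socket

def get_free_port() -> int:
    # Binding to port 0 lets the OS pick an unused port; the socket is closed
    # immediately and the number is reused for the static rendezvous endpoint.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]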
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generating a new one: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        f"Starting elastic_operator with launch configs:\n"
        f"  entrypoint       : {entrypoint_name}\n"
        f"  min_nodes        : {config.min_nodes}\n"
        f"  max_nodes        : {config.max_nodes}\n"
        f"  nproc_per_node   : {config.nproc_per_node}\n"
        f"  run_id           : {config.run_id}\n"
        f"  rdzv_backend     : {config.rdzv_backend}\n"
        f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
        f"  rdzv_configs     : {config.rdzv_configs}\n"
        f"  max_restarts     : {config.max_restarts}\n"
        f"  monitor_interval : {config.monitor_interval}\n"
        f"  log_dir          : {config.log_dir}\n"
        f"  metrics_cfg      : {config.metrics_cfg}\n"
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    agent = None
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler,
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
        metrics.initialize_metrics(cfg)

        agent = LocalElasticAgent(
            spec=spec, start_method=config.start_method, log_dir=config.log_dir
        )

        result = agent.run()
        events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )
        else:
            return result.return_values
    except ChildFailedError:
        raise
    except Exception:
        if agent:
            events.record(agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(config))
        raise
    finally:
        rdzv_handler.shutdown()
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        f"Starting elastic_operator with launch configs:\n"
        f"  entrypoint       : {entrypoint_name}\n"
        f"  min_nodes        : {config.min_nodes}\n"
        f"  max_nodes        : {config.max_nodes}\n"
        f"  nproc_per_node   : {config.nproc_per_node}\n"
        f"  run_id           : {config.run_id}\n"
        f"  rdzv_backend     : {config.rdzv_backend}\n"
        f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
        f"  rdzv_configs     : {config.rdzv_configs}\n"
        f"  max_restarts     : {config.max_restarts}\n"
        f"  monitor_interval : {config.monitor_interval}\n"
        f"  log_dir          : {config.log_dir}\n"
        f"  metrics_cfg      : {config.metrics_cfg}\n"
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
    )

    agent = LocalElasticAgent(
        spec=spec, start_method=config.start_method, log_dir=config.log_dir
    )

    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler,
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
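# In practice `launch_agent` is rarely called directly; a hedged sketch of the usual
# entry point through `elastic_launch` and `LaunchConfig`. The `trainer` function and
# its argument below are placeholders, not part of the source.
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def trainer(batch_size):
    # hypothetical per-worker entrypoint; real training code would go here
    return batch_size

if __name__ == "__main__":
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29400",
        max_restarts=0,
    )
    # elastic_launch wraps launch_agent and returns a dict of rank -> return value.
    results = elastic_launch(config, trainer)(32)
    print(results)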
def main():
    args = parse_args()
    current_env = os.environ.copy()

    for k in current_env.keys():
        if "NCCL" in k:
            logger.info(f"{args.node_rank} {k}={current_env[k]}")

    if args.world_info == "None":
        raise ValueError("world_info can not be None")
    world_info = base64.urlsafe_b64decode(args.world_info)
    world_info = json.loads(world_info)

    logger.info(f"WORLD INFO DICT: {world_info}")
    node_list = list(world_info.keys())
    args.nnodes = len(node_list)
    local_node = node_list[args.node_rank]
    local_gpu_ids = world_info[local_node]
    num_local_procs = len(local_gpu_ids)
    logger.info(
        f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
    )

    global_rank_mapping = defaultdict(list)
    curr_global_rank = 0
    dist_world_size = 0
    for node_id in node_list:
        gids = world_info[node_id]
        dist_world_size += len(gids)
        for gid in gids:
            global_rank_mapping[node_id].append(curr_global_rank)
            curr_global_rank += 1
    logger.info(f"global_rank_mapping={global_rank_mapping}")
    logger.info(f"dist_world_size={dist_world_size}")

    current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids))
    logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}")

    # set PyTorch distributed related environmental variables
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    current_env["CROSS_RANK"] = str(args.node_rank)
    current_env["CROSS_SIZE"] = str(args.nnodes)
    current_env["LOCAL_SIZE"] = str(num_local_procs)

    if args.save_pid:
        print(f"launcher pid: {os.getpid()}")

    pid_file = None
    if args.save_pid:
        launcher_pid = os.getpid()
        pid_file = os.path.join(PID_FILE_BASEPATH, f"{args.save_pid}.deepspeed")
        assert not os.path.isfile(pid_file), "pid file exists but shouldn't"
        with open(pid_file, 'w') as fd:
            fd.write(f"{launcher_pid}")

    if not is_torch_elastic_compatible():
        if args.enable_elastic_training:
            logger.info("Disabling elastic training support as "
                        "PyTorch version should be greater than 1.11.x")
            args.enable_elastic_training = False

    if os.path.exists(DLTS_POD_ENV_PATH):
        with open(DLTS_POD_ENV_PATH) as file:
            lines = file.readlines()
            lines = [line.rstrip() for line in lines]
            for line in lines:
                if line.startswith('export FC_TASKROLE_NAME') or line.startswith('export FC_TASK_INDEX'):
                    key_val = line.split()[1]
                    key, val = key_val.split('=')
                    current_env[key] = val

    processes = []
    cmd = []

    if not args.enable_elastic_training:
        for local_rank in range(0, num_local_procs):
            # each process's rank
            dist_rank = global_rank_mapping[local_node][local_rank]
            current_env["RANK"] = str(dist_rank)
            current_env["LOCAL_RANK"] = str(local_rank)

            # spawn the processes
            cmd = []
            if not args.no_python:
                cmd = [sys.executable, "-u"]
                if args.module:
                    cmd.append("-m")
            else:
                if args.module:
                    raise ValueError("Don't use both the '--no_python' flag"
                                     " and the '--module' flag at the same time.")
            cmd.append(args.training_script)
            # A user may not want to pass local_rank as a keyword arg so we make this optional.
            if not args.no_local_rank:
                cmd.append(f"--local_rank={local_rank}")
            cmd += args.training_script_args

            process = subprocess.Popen(cmd, env=current_env)
            processes.append(process)
    else:
        from ..elasticity import DSElasticAgent
        from torch.distributed.elastic.rendezvous import RendezvousParameters
        from torch.distributed.elastic.agent.server.api import WorkerSpec
        import torch.distributed.elastic.rendezvous.registry as rdzv_registry
        from torch.distributed.elastic.multiprocessing import Std

        if args.min_elastic_nodes == -1:
            args.min_elastic_nodes = 1
        if args.max_elastic_nodes == -1:
            args.max_elastic_nodes = args.nnodes
        assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, \
            "Max and Min nodes should be positive"

        current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)

        # Get config and arguments
        cmd = []
        if not args.no_python:
            cmd = [sys.executable, "-u"]
            if args.module:
                cmd.append("-m")
        else:
            if args.module:
                raise ValueError("Don't use both the '--no_python' flag"
                                 " and the '--module' flag at the same time.")
        cmd.append(args.training_script)
        cmd += args.training_script_args
        cmd_args = cmd[1:]

        rdzv_configs: Dict[str, str] = {'timeout': 100}
        run_id = os.environ.get("ELASTIC_RUN_ID", ELASTIC_TRAINING_ID_DEFAULT)

        # Creating config for rendezvous class
        rdzv_parameters = RendezvousParameters(
            backend='c10d',
            endpoint=args.master_addr + ":" + str(args.master_port),
            run_id=run_id,
            min_nodes=args.min_elastic_nodes,
            max_nodes=args.max_elastic_nodes,
            **rdzv_configs,
        )

        spec = WorkerSpec(
            role='trainer',
            local_world_size=num_local_procs,
            entrypoint=cmd[0],
            args=cmd[1:],
            rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
            max_restarts=100,
            monitor_interval=5,
            redirects=Std.from_str("0"),
            tee=Std.from_str("0"),
            master_addr=None,
            master_port=None,
        )
        agent = DSElasticAgent(spec, current_env)
        agent.run()

    sig_names = {2: "SIGINT", 15: "SIGTERM"}
    last_return_code = None

    def sigkill_handler(signum, frame):
        for process in processes:
            logger.info(f"Killing subprocess {process.pid}")
            try:
                terminate_process_tree(process.pid)
            except Exception:
                pass
        if last_return_code is not None:
            logger.error(f"{cmd} exits with return code = {last_return_code}")
            sys.exit(last_return_code)
        if signum in sig_names:
            logger.info(f"Main process received {sig_names[signum]}, exiting")
        if args.save_pid:
            if os.path.isfile(pid_file):
                os.remove(pid_file)
        sys.exit(1)

    # pass SIGINT/SIGTERM to children if the parent is being terminated
    signal.signal(signal.SIGINT, sigkill_handler)
    signal.signal(signal.SIGTERM, sigkill_handler)

    alive_processes = set(processes)
    while len(alive_processes):
        finished_processes = []
        for process in alive_processes:
            if process.poll() is None:
                # the process is still running
                continue
            else:
                if process.returncode != 0:
                    last_return_code = process.returncode  # for sigkill_handler
                    sigkill_handler(signal.SIGTERM, None)  # not coming back
                else:
                    # exited cleanly
                    logger.info(f"Process {process.pid} exits successfully.")
                    finished_processes.append(process)
        alive_processes = set(alive_processes) - set(finished_processes)

        time.sleep(1)
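# For reference, a small hedged example of the `world_info` payload the launcher above
# decodes: a base64-encoded JSON dict mapping each host to its usable GPU ids, from
# which global ranks are assigned in host order. Host names here are made up.
import base64
import json
from collections import defaultdict

world_info = {"worker-0": [0, 1], "worker-1": [0, 1]}
encoded = base64.urlsafe_b64encode(json.dumps(world_info).encode()).decode()

decoded = json.loads(base64.urlsafe_b64decode(encoded))
global_rank_mapping = defaultdict(list)
rank = 0
for node, gpu_ids in decoded.items():
    for _ in gpu_ids:
        global_rank_mapping[node].append(rank)
        rank += 1
# prints {'worker-0': [0, 1], 'worker-1': [2, 3]} with dist_world_size = 4
print(dict(global_rank_mapping))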