def _exit_barrier(self):
    """
    Wait for ``exit_barrier_timeout`` seconds for all agents to finish
    executing their local workers (either successfully or not). This acts
    as a safety guard against user scripts that terminate at different
    times. This barrier keeps the agent process alive until all workers
    finish.
    """
    log.info(
        f"Local worker group finished ({self._worker_group.state}). "
        f"Waiting {self._exit_barrier_timeout} seconds for other agents to finish"
    )
    start = time.time()
    try:
        store_util.barrier(
            self._store,
            self._worker_group.group_rank,
            self._worker_group.group_world_size,
            key_prefix=_TERMINAL_STATE_SYNC_ID,
            barrier_timeout=self._exit_barrier_timeout,
        )
        log.info(
            f"Done waiting for other agents. Elapsed: {time.time() - start} seconds"
        )
    except Exception:
        log.exception(
            f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
        )
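
# A minimal sketch of the barrier semantics relied on above (an assumption for
# illustration, not the actual store_util.barrier implementation): every agent
# atomically increments a shared counter in the store, the last arrival sets a
# release key, and all agents block on that key until it appears or the timeout
# expires. `store` is assumed to be any c10d store (e.g. torch.distributed.TCPStore).
from datetime import timedelta


def _toy_exit_barrier(store, world_size: int, key_prefix: str, timeout: float) -> None:
    arrived = store.add(f"{key_prefix}/arrived", 1)  # atomic fetch-and-add
    if arrived == world_size:
        store.set(f"{key_prefix}/released", "1")  # last agent releases the rest
    # raises if the release key does not show up within `timeout` seconds
    store.wait([f"{key_prefix}/released"], timedelta(seconds=timeout))
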
def _invoke_run(self, role: str = DEFAULT_ROLE) -> Dict[int, Any]:
    # NOTE: currently only works for a single role
    spec = self._worker_group.spec
    role = spec.role

    log.info(f"[{role}] starting workers for function: {spec.fn.__name__}")

    self._initialize_workers(self._worker_group)
    monitor_interval = spec.monitor_interval
    rdzv_handler = spec.rdzv_handler

    while True:
        assert self._worker_group.state != WorkerState.INIT
        # poll the local workers at the configured monitoring interval
        time.sleep(monitor_interval)
        monitor_result = self._monitor_workers(self._worker_group)
        state = monitor_result.state
        self._worker_group.state = state

        put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
        put_metric(f"workers.{role}.{state.name.lower()}", 1)

        if state == WorkerState.SUCCEEDED:
            log.info(
                f"[{role}] worker group successfully finished."
                f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
            )
            # sync with the other agents before returning so that all nodes
            # reach their terminal state at roughly the same time
            try:
                store_util.barrier(
                    self._store,
                    self._worker_group.group_rank,
                    self._worker_group.group_world_size,
                    key_prefix=_TERMINAL_STATE_SYNC_ID,
                    barrier_timeout=self._exit_barrier_timeout,
                )
            except Exception:
                log.exception(
                    "Local worker group succeeded, but exit barrier failed while waiting for other nodes"
                )
            return monitor_result.ret_vals
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            if self._remaining_restarts > 0:
                log.info(
                    f"[{role}] Worker group {state.name}. "
                    f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                    f" will restart worker group"
                )
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)
            else:
                # no restart budget left: stop the workers, sync with peers,
                # then surface the failure to the caller
                self._stop_workers(self._worker_group)
                self._worker_group.state = WorkerState.FAILED
                msg = f"[{role}] exceeded max_restarts={spec.max_restarts}"
                log.error(
                    f"{msg}. Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                )
                try:
                    store_util.barrier(
                        self._store,
                        self._worker_group.group_rank,
                        self._worker_group.group_world_size,
                        key_prefix=_TERMINAL_STATE_SYNC_ID,
                        barrier_timeout=self._exit_barrier_timeout,
                    )
                except Exception:
                    log.exception(
                        "Local worker group failed waiting for other nodes."
                    )
                raise WorkerGroupFailureException(msg, monitor_result.exceptions)
        elif state == WorkerState.HEALTHY:
            # membership changes do not count as retries
            num_nodes_waiting = rdzv_handler.num_nodes_waiting()
            group_rank = self._worker_group.group_rank
            if num_nodes_waiting > 0:
                log.info(
                    f"[{role}] Detected {num_nodes_waiting} "
                    f"new nodes from group_rank={group_rank}; "
                    f"will restart worker group"
                )
                self._restart_workers(self._worker_group)
        else:
            raise Exception(f"[{role}] Worker group in {state.name} state")
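
# Hypothetical usage sketch (an assumption for illustration; `agent` is a
# placeholder for an already-constructed elastic agent whose public run()
# entry point is assumed to forward _invoke_run's return value). It shows the
# two terminal outcomes of the loop above: success returns the per-local-rank
# results, while exhausting max_restarts raises WorkerGroupFailureException
# carrying the per-rank worker exceptions.
def _example_run(agent) -> None:
    try:
        ret_vals = agent.run()  # Dict[int, Any], keyed by local rank
        log.info(f"workers finished; rank 0 returned: {ret_vals.get(0)}")
    except WorkerGroupFailureException:
        # max_restarts exhausted; the exception wraps the workers' exceptions
        log.exception("worker group failed after exhausting restarts")
        raise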