def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
    """Run the worker group to completion, restarting on failure.

    Initializes the workers, then polls them every ``monitor_interval``
    seconds until they either all succeed, permanently fail (restart
    budget exhausted), or a membership change triggers a restart.

    Returns the final :class:`RunResult` from the last monitoring pass.
    Raises ``Exception`` if the worker group enters an unknown state.
    """
    # NOTE: currently only works for a single role
    spec = self._worker_group.spec
    # The ``role`` argument is ignored; the spec's role always wins.
    role = spec.role
    if spec.fn:
        log.info(
            f"[{role}] starting workers for function: {spec.fn.__name__}")
    else:
        log.info(f"[{role}] starting workers for cmd: {spec.cmd}")
    self._initialize_workers(self._worker_group)
    monitor_interval = spec.monitor_interval
    rdzv_handler = spec.rdzv_handler
    while True:
        # _initialize_workers must have moved us past INIT before polling.
        assert self._worker_group.state != WorkerState.INIT
        time.sleep(monitor_interval)
        run_result = self._monitor_workers(self._worker_group)
        state = run_result.state
        self._worker_group.state = state
        put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
        put_metric(f"workers.{role}.{state.name.lower()}", 1)
        if state == WorkerState.SUCCEEDED:
            log.info(
                f"[{role}] worker group successfully finished."
                f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
            )
            # Synchronize with peer agents before declaring success.
            self._exit_barrier()
            return run_result
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            if self._remaining_restarts > 0:
                log.info(
                    f"[{role}] Worker group {state.name}. "
                    f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                    f" will restart worker group")
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)
            else:
                # Restart budget exhausted: stop workers, mark the group
                # FAILED, and still rendezvous at the exit barrier so peer
                # agents are not left waiting.
                self._stop_workers(self._worker_group)
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result
        elif state == WorkerState.HEALTHY:
            # membership changes do not count as retries
            num_nodes_waiting = rdzv_handler.num_nodes_waiting()
            group_rank = self._worker_group.group_rank
            if num_nodes_waiting > 0:
                log.info(f"[{role}] Detected {num_nodes_waiting} "
                         f"new nodes from group_rank={group_rank}; "
                         f"will restart worker group")
                self._restart_workers(self._worker_group)
        else:
            raise Exception(f"[{role}] Worker group in {state.name} state")
def _record_flakiness_metric(self, is_failed: bool = False):
    """Emit a flakiness score (0-100) for this worker group's run.

    A failed run scores 100. Otherwise the score is the fraction of the
    restart budget that was consumed: 0 when no restarts were used,
    approaching 100 as ``max_restarts`` is exhausted.

    Args:
        is_failed: whether the run ultimately failed.
    """
    # Fix: ``spec`` was previously assigned twice (inside the else branch
    # and again unconditionally); hoist the single assignment up front.
    spec = self._worker_group.spec
    if is_failed:
        flakiness = 100.0
    else:
        # +1 on both sides counts the initial attempt alongside restarts,
        # and keeps the denominator non-zero when max_restarts == 0.
        flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
            spec.max_restarts + 1
        )
    put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
def run(self, role: str = DEFAULT_ROLE) -> Dict[int, Any]:
    """Run the worker group's function to completion, restarting on failure.

    Polls the workers every ``monitor_interval`` seconds. On success,
    returns a map of worker rank to the worker's return value. On
    failure with restarts remaining, restarts the group; once the budget
    is exhausted, raises :class:`WorkerGroupFailureException` carrying
    the collected worker exceptions. Membership changes (new nodes
    waiting at rendezvous) trigger a restart without consuming a retry.
    """
    # NOTE: currently only works for a single role
    spec = self._worker_group.spec
    # The ``role`` argument is ignored; the spec's role always wins.
    role = spec.role
    log.info(f"[{role}] starting workers for function: {spec.fn.__name__}")
    self._initialize_workers(self._worker_group)
    monitor_interval = spec.monitor_interval
    rdzv_handler = spec.rdzv_handler
    while True:
        # _initialize_workers must have moved us past INIT before polling.
        assert self._worker_group.state != WorkerState.INIT
        time.sleep(monitor_interval)
        monitor_result = self._monitor_workers(self._worker_group)
        state = monitor_result.state
        self._worker_group.state = state
        put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
        put_metric(f"workers.{role}.{state.name.lower()}", 1)
        if state == WorkerState.SUCCEEDED:
            log.info(f"[{role}] All workers successfully finished.")
            return monitor_result.ret_vals
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            if self._remaining_restarts > 0:
                log.info(
                    f"[{role}] Worker group {state.name}. "
                    f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                    f" will restart worker group"
                )
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)
            else:
                # Restart budget exhausted: stop workers and surface the
                # per-worker exceptions to the caller.
                self._stop_workers(self._worker_group)
                self._worker_group.state = WorkerState.FAILED
                raise WorkerGroupFailureException(
                    f"[{role}] exceeded max_restarts={spec.max_restarts}",
                    monitor_result.exceptions,
                )
        elif state == WorkerState.HEALTHY:
            # membership changes do not count as retries
            num_nodes_waiting = rdzv_handler.num_nodes_waiting()
            group_rank = self._worker_group.group_rank
            if num_nodes_waiting > 0:
                log.info(
                    f"[{role}] Detected {num_nodes_waiting} "
                    f"new nodes from group_rank={group_rank}; "
                    f"will restart worker group"
                )
                self._restart_workers(self._worker_group)
        else:
            raise Exception(f"[{role}] Worker group in {state.name} state")
def _record_flakiness_metric(self, is_failed: bool = False):
    """Emit a flakiness score (0-100) for this worker group's run.

    A failed run scores 100, but only when the failure originated in
    user code (i.e. ``is_failed`` is a ``WorkerGroupFailureException``);
    agent/infrastructure failures are skipped. Otherwise the score is
    the fraction of the restart budget consumed.

    Args:
        is_failed: falsy for a successful run; on failure, callers are
            expected to pass the ``WorkerGroupFailureException`` itself.
    """
    # Fix: ``spec`` was previously assigned twice (inside the else branch
    # and again unconditionally); hoist the single assignment up front.
    spec = self._worker_group.spec
    if is_failed:
        # NOTE(review): if callers honor the ``bool`` annotation this
        # isinstance check can never pass and a failed run records no
        # metric at all — the guard only makes sense when the exception
        # object itself is passed. Confirm call sites and widen the
        # annotation accordingly.
        if not isinstance(is_failed, WorkerGroupFailureException):
            # Only user code can contribute into flakiness score
            return
        flakiness = 100.0
    else:
        # +1 on both sides counts the initial attempt alongside restarts,
        # and keeps the denominator non-zero when max_restarts == 0.
        flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
            spec.max_restarts + 1
        )
    put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
def _record_metrics(self, is_failed: bool = False):
    """Emit run-outcome metrics: flakiness, total runs, and exactly one
    of the four success/failure x retried/not-retried outcome counters
    set to 1 (the other three set to 0).

    Args:
        is_failed: whether the run ultimately failed.
    """
    self._record_flakiness_metric(is_failed)
    spec = self._worker_group.spec
    retried = self._remaining_restarts != spec.max_restarts
    put_metric(f"workers.{spec.role}.run_total", 1)
    # Exactly one of these conditions is true; emit all four so the
    # zero-valued counters are recorded too (insertion order preserved).
    outcomes = {
        "run_success_with_retries": not is_failed and retried,
        "run_success_no_retries": not is_failed and not retried,
        "run_failed_with_retries": is_failed and retried,
        "run_failed_no_retries": is_failed and not retried,
    }
    for metric_name, condition in outcomes.items():
        self._record_metric_with_condition(metric_name, condition)
def _record_metric_with_condition(self, metric_name, condition):
    """Emit 1 for the role-scoped metric when ``condition`` is truthy,
    else 0.

    Args:
        metric_name: metric suffix, published as
            ``workers.<role>.<metric_name>``.
        condition: any truthy/falsy value.
    """
    spec = self._worker_group.spec
    # Idiom fix: collapse the duplicated put_metric call in both branches
    # into a single call with a conditional value.
    put_metric(f"workers.{spec.role}.{metric_name}", 1 if condition else 0)
def _invoke_run(self, role: str = DEFAULT_ROLE) -> Dict[int, Any]:
    """Run the worker group's function to completion, restarting on failure.

    Polls the workers every ``monitor_interval`` seconds. On success,
    joins the store-based exit barrier with peer agents and returns a
    map of worker rank to return value. Once the restart budget is
    exhausted, joins the barrier and raises
    :class:`WorkerGroupFailureException` with the collected worker
    exceptions. Membership changes trigger a restart without consuming
    a retry. Barrier failures are logged but never mask the run result.
    """
    # NOTE: currently only works for a single role
    spec = self._worker_group.spec
    # The ``role`` argument is ignored; the spec's role always wins.
    role = spec.role
    log.info(f"[{role}] starting workers for function: {spec.fn.__name__}")
    self._initialize_workers(self._worker_group)
    monitor_interval = spec.monitor_interval
    rdzv_handler = spec.rdzv_handler
    while True:
        # _initialize_workers must have moved us past INIT before polling.
        assert self._worker_group.state != WorkerState.INIT
        time.sleep(monitor_interval)
        monitor_result = self._monitor_workers(self._worker_group)
        state = monitor_result.state
        self._worker_group.state = state
        put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
        put_metric(f"workers.{role}.{state.name.lower()}", 1)
        if state == WorkerState.SUCCEEDED:
            log.info(
                f"[{role}] worker group successfully finished."
                f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
            )
            try:
                # Synchronize terminal state with peer agents; a barrier
                # failure is logged but does not fail a successful run.
                store_util.barrier(
                    self._store,
                    self._worker_group.group_rank,
                    self._worker_group.group_world_size,
                    key_prefix=_TERMINAL_STATE_SYNC_ID,
                    barrier_timeout=self._exit_barrier_timeout,
                )
            except Exception:
                log.exception(
                    "Local worker group succeeded, but exit barrier failed while waiting for other nodes"
                )
            return monitor_result.ret_vals
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            if self._remaining_restarts > 0:
                log.info(
                    f"[{role}] Worker group {state.name}. "
                    f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                    f" will restart worker group"
                )
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)
            else:
                # Restart budget exhausted: stop workers, mark FAILED,
                # then still rendezvous at the barrier so peers learn of
                # the terminal state before we raise.
                self._stop_workers(self._worker_group)
                self._worker_group.state = WorkerState.FAILED
                msg = f"[{role}] exceeded max_restarts={spec.max_restarts}"
                log.error(
                    f"{msg}. Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                )
                try:
                    store_util.barrier(
                        self._store,
                        self._worker_group.group_rank,
                        self._worker_group.group_world_size,
                        key_prefix=_TERMINAL_STATE_SYNC_ID,
                        barrier_timeout=self._exit_barrier_timeout,
                    )
                except Exception:
                    log.exception(
                        "Local worker group failed waiting for other nodes."
                    )
                raise WorkerGroupFailureException(msg, monitor_result.exceptions)
        elif state == WorkerState.HEALTHY:
            # membership changes do not count as retries
            num_nodes_waiting = rdzv_handler.num_nodes_waiting()
            group_rank = self._worker_group.group_rank
            if num_nodes_waiting > 0:
                log.info(
                    f"[{role}] Detected {num_nodes_waiting} "
                    f"new nodes from group_rank={group_rank}; "
                    f"will restart worker group"
                )
                self._restart_workers(self._worker_group)
        else:
            raise Exception(f"[{role}] Worker group in {state.name} state")