Exemplo n.º 1
0
 def _exit_barrier(self):
     """
     Wait on the terminal-state barrier for the other agents to finish.

     Waits up to ``exit_barrier_timeout`` seconds for all agents to finish
     executing their local workers (either successfully or not). This acts as
     a safety guard against user scripts that terminate at different times:
     it keeps the agent process alive until the workers of all agents finish.

     This is best-effort. Any exception raised by ``store_util.barrier``
     (presumably including a timeout — confirm against ``store_util``) is
     logged together with the elapsed time and is not re-raised.
     """
     log.info(
         f"Local worker group finished ({self._worker_group.state}). "
         f"Waiting {self._exit_barrier_timeout} seconds for other agents to finish"
     )
     # Use a monotonic clock: wall-clock time (time.time) can jump on clock
     # adjustments and would then report a wrong (even negative) elapsed time.
     start = time.monotonic()
     try:
         store_util.barrier(
             self._store,
             self._worker_group.group_rank,
             self._worker_group.group_world_size,
             key_prefix=_TERMINAL_STATE_SYNC_ID,
             barrier_timeout=self._exit_barrier_timeout,
         )
     except Exception:
         # Deliberately swallowed: a failed barrier must not mask the local outcome.
         log.exception(
             f"Error waiting on exit barrier. Elapsed: {time.monotonic() - start} seconds"
         )
     else:
         log.info(
             f"Done waiting for other agents. Elapsed: {time.monotonic() - start} seconds"
         )
Exemplo n.º 2
0
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> Dict[int, Any]:
        """
        Start the local worker group and monitor it until it reaches a terminal state.

        Every ``spec.monitor_interval`` seconds the workers are polled, and the
        resulting state is stored on the worker group and reported as metrics:

        * ``SUCCEEDED``: wait on the terminal-state barrier for the other agents,
          then return the workers' return values.
        * ``UNHEALTHY`` / ``FAILED``: restart the worker group while restarts
          remain. Otherwise stop the workers, mark the group ``FAILED``, wait on
          the barrier and raise.
        * ``HEALTHY``: if the rendezvous handler reports nodes waiting to join,
          restart the worker group. Such membership changes do not consume a
          restart.
        * any other state: raise.

        Args:
            role: ignored. The role is always taken from ``spec.role`` because
                only a single role is currently supported.

        Returns:
            ``ret_vals`` of the monitor result once the group has succeeded.

        Raises:
            WorkerGroupFailureException: the group failed and no restarts are left.
            Exception: the group reached an unexpected state.
        """
        # NOTE: currently only works for a single role
        spec = self._worker_group.spec
        role = spec.role

        log.info(f"[{role}] starting workers for function: {spec.fn.__name__}")

        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            monitor_result = self._monitor_workers(self._worker_group)
            state = monitor_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                log.info(
                    f"[{role}] worker group successfully finished."
                    f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                )
                self._wait_for_other_agents(
                    "Local worker group succeeded, but exit barrier failed while waiting for other nodes"
                )
                return monitor_result.ret_vals
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                if self._remaining_restarts > 0:
                    log.info(
                        f"[{role}] Worker group {state.name}. "
                        f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                        f" will restart worker group"
                    )
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    msg = f"[{role}] exceeded max_restarts={spec.max_restarts}"
                    log.error(
                        f"{msg}. Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                    )
                    self._wait_for_other_agents(
                        "Local worker group failed waiting for other nodes."
                    )
                    raise WorkerGroupFailureException(msg, monitor_result.exceptions)
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    log.info(
                        f"[{role}] Detected {num_nodes_waiting} "
                        f"new nodes from group_rank={group_rank}; "
                        f"will restart worker group"
                    )
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

    def _wait_for_other_agents(self, failure_msg: str) -> None:
        """
        Wait on the terminal-state barrier shared with the other agents (best effort).

        Calls ``store_util.barrier`` with ``self._exit_barrier_timeout`` as the
        timeout. Any exception it raises is logged (with traceback) under
        ``failure_msg`` and swallowed. This lets the caller still return its
        result or raise its own failure.

        Args:
            failure_msg: message logged if the barrier raises.
        """
        try:
            store_util.barrier(
                self._store,
                self._worker_group.group_rank,
                self._worker_group.group_world_size,
                key_prefix=_TERMINAL_STATE_SYNC_ID,
                barrier_timeout=self._exit_barrier_timeout,
            )
        except Exception:
            log.exception(failure_msg)