def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
    """Poll the process context once and translate its status into a RunResult.

    Verifies that the pids tracked by the worker group still match the pids
    of the underlying process context, then does a non-blocking wait. Local
    rank keys in the wait result are re-keyed by each worker's global rank.
    """
    role = worker_group.spec.role

    expected_pids = {w.id for w in worker_group.workers}
    assert self._pcontext is not None
    actual_pids = set(self._pcontext.pids().values())
    if expected_pids != actual_pids:
        log.error(
            f"[{role}] worker pids do not match process_context pids."
            f" Expected: {expected_pids}, actual: {actual_pids}")
        return RunResult(state=WorkerState.UNKNOWN)

    # Non-blocking poll; a falsy result means the workers are still running.
    result = self._pcontext.wait(0)
    if not result:
        return RunResult(state=WorkerState.HEALTHY)

    if result.is_failed():
        # Re-key failures from local rank to global rank.
        failures = {
            worker_group.workers[local_rank].global_rank: failure
            for local_rank, failure in result.failures.items()
        }
        return RunResult(state=WorkerState.FAILED, failures=failures)

    # Re-key return values from local rank to global rank.
    return_values = {
        worker_group.workers[local_rank].global_rank: ret_val
        for local_rank, ret_val in result.return_values.items()
    }
    return RunResult(state=WorkerState.SUCCEEDED, return_values=return_values)
def monres(state: WorkerState):
    """Build a canned RunResult for the given worker state (test helper).

    SUCCEEDED gets a single dummy return value, UNHEALTHY/FAILED get a single
    dummy ProcessFailure, and any other state yields a bare RunResult.
    """
    if state == WorkerState.SUCCEEDED:
        return RunResult(state=state, return_values={0: 0}, failures={})
    if state in (WorkerState.UNHEALTHY, WorkerState.FAILED):
        failure = ProcessFailure(
            local_rank=0, pid=999, exitcode=1, error_file="<none>"
        )
        return RunResult(state=state, return_values={}, failures={0: failure})
    return RunResult(state=state)
def test_record_metrics_success_no_retries(self, put_metric_mock):
    """Metrics for a successful run that consumed no restarts.

    Fix: the original built `RunResult({}, {})`, binding a dict to the
    positional `state` parameter. It only passed because `{}` is not equal
    to WorkerState.FAILED; use an explicit non-failed WorkerState so
    is_failed() is evaluated against a real state value.
    """
    spec = self._get_worker_spec(max_restarts=1)
    agent = TestAgent(spec)
    group_result = RunResult(
        state=WorkerState.SUCCEEDED, return_values={}, failures={}
    )
    agent._record_metrics(group_result)
    calls = self._get_record_metrics_test_calls(success_no_retries=1)
    put_metric_mock.assert_has_calls(calls, any_order=True)
def test_record_metrics_failed_no_retries(self, put_metric_mock):
    """Metrics for a failed run where no restarts were consumed."""
    spec = self._get_worker_spec(max_restarts=10)
    agent = TestAgent(spec)

    failed_result = RunResult(
        state=WorkerState.FAILED,
        return_values={},
        failures={0: 0},
    )
    agent._record_metrics(failed_result)

    expected_calls = self._get_record_metrics_test_calls(failed_no_retries=1)
    put_metric_mock.assert_has_calls(expected_calls, any_order=True)
def test_shutdown_called(self, start_processes_mock):
    """agent.run() must close the process context when the run completes."""
    pcontext_mock = Mock()
    pcontext_mock.pids.return_value = {0: 0}
    start_processes_mock.return_value = pcontext_mock

    conf = Conf(entrypoint=_happy_function, local_world_size=1)
    worker_spec = self.get_worker_spec(conf, max_restarts=0)
    agent = self.get_agent(worker_spec)

    # Short-circuit monitoring so the run reports immediate success.
    monitor_result = RunResult(
        state=WorkerState.SUCCEEDED, return_values={0: 0}
    )
    with patch.object(agent, "_monitor_workers", return_value=monitor_result):
        agent.run("worker")

    pcontext_mock.close.assert_called_once()
def test_launch_shutdown(self, agent_mock_cls):
    """elastic_launch must shut down the rendezvous handler after a run."""
    agent_mock = Mock()
    agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
    agent_mock_cls.return_value = agent_mock

    rdzv_handler_mock = Mock()
    with patch(
        "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
    ) as get_handler_mock:
        get_handler_mock.return_value = rdzv_handler_mock
        launcher = elastic_launch(
            self.get_test_launch_config(1, 1, 4),
            sys.executable,
        )
        launcher(
            "-u",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        )

    rdzv_handler_mock.shutdown.assert_called_once()
def test_launch_shutdown(self, agent_mock_cls):
    """launch.main() must shut down the rendezvous handler when done."""
    agent_mock = Mock()
    agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
    agent_mock_cls.return_value = agent_mock

    cli_args = [
        "--nnodes=1",
        "--nproc_per_node=4",
        "--monitor_interval=1",
        "--start_method=fork",
        path("bin/test_script.py"),
        f"--touch_file_dir={self.test_dir}",
    ]

    rdzv_handler_mock = Mock()
    with patch(
        "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
    ) as get_handler_mock:
        get_handler_mock.return_value = rdzv_handler_mock
        launch.main(cli_args)

    rdzv_handler_mock.shutdown.assert_called_once()