Example #1
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            log.error(
                f"[{role}] worker pids do not match process_context pids."
                f" Expected: {worker_pids}, actual: {pc_pids}")
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0)
        if result:
            if result.is_failed():
                log.error(f"[{role}] Worker group failed")
                # map local rank failure to global rank
                worker_failures = {}
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,
                )
            else:
                # copy ret_val_queue into a map keyed by global rank
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)
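
All of the examples on this page build on the same RunResult shape. The sketch below shows that shape as the examples use it (state, return_values keyed by global rank, failures keyed by global rank, plus an is_failed() helper); the defaults and import paths follow recent PyTorch and are an assumption, not the authoritative definition — older torchelastic used different module paths, as Example #9 shows.

from dataclasses import dataclass, field
from typing import Any, Dict

from torch.distributed.elastic.agent.server.api import WorkerState
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure


@dataclass
class RunResult:
    # state of the worker group after a monitoring pass
    state: WorkerState
    # global rank -> return value (populated when the group SUCCEEDED)
    return_values: Dict[int, Any] = field(default_factory=dict)
    # global rank -> ProcessFailure (populated when the group FAILED)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return self.state == WorkerState.FAILED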
Example #2
def monres(state: WorkerState):
    if state == WorkerState.SUCCEEDED:
        return RunResult(state=state, return_values={0: 0}, failures={})
    elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
        pf = ProcessFailure(local_rank=0, pid=999, exitcode=1, error_file="<none>")
        return RunResult(state=state, return_values={}, failures={0: pf})
    else:
        return RunResult(state=state)
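
A helper like monres() is typically wired into a test by patching the agent's _monitor_workers and feeding it a canned sequence of states. The sketch below is illustrative only; the helper name and the "worker" role argument are assumptions based on the other examples on this page, not code from the test suite.

from unittest.mock import patch


def run_agent_through_states(agent, states):
    # Return one canned RunResult per monitoring pass, in order.
    with patch.object(agent, "_monitor_workers") as monitor_mock:
        monitor_mock.side_effect = [monres(s) for s in states]
        return agent.run("worker")


# e.g. stay healthy for one pass, then succeed:
# result = run_agent_through_states(agent, [WorkerState.HEALTHY, WorkerState.SUCCEEDED])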
Example #3
def monres(state: WorkerState):
    if state == WorkerState.SUCCEEDED:
        return RunResult(state=state, return_values={0: 0}, failures={})
    elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
        pf = ProcessFailure(local_rank=0, pid=0, exitcode=1, error_file="<none>")
        return RunResult(state=state,
                         return_values={},
                         failures={0: pf})
    else:
        return RunResult(state=state)
Example #4
 def test_record_metrics_success_no_retries(self, put_metric_mock):
     spec = self._get_worker_spec(max_restarts=1)
     agent = TestAgent(spec)
     group_result = RunResult(state=WorkerState.SUCCEEDED,
                              return_values={},
                              failures={})
     agent._record_metrics(group_result)
     calls = self._get_record_metrics_test_calls(success_no_retries=1)
     put_metric_mock.assert_has_calls(calls, any_order=True)
Example #5
 def test_record_metrics_failed_no_retries(self, put_metric_mock):
     spec = self._get_worker_spec(max_restarts=10)
     agent = TestAgent(spec)
     # the failures value is a placeholder; this test only checks the failed-run metric
     group_result = RunResult(state=WorkerState.FAILED,
                              return_values={},
                              failures={0: 0})
     agent._record_metrics(group_result)
     calls = self._get_record_metrics_test_calls(failed_no_retries=1)
     put_metric_mock.assert_has_calls(calls, any_order=True)
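
Examples #4 and #5 exercise the two "no retries" metrics paths. The sketch below shows the four outcomes such a _record_metrics implementation would distinguish; the metric names and the put_metric helper are illustrative assumptions, not the real API.

def record_run_metrics(result, remaining_restarts, max_restarts, put_metric):
    # "with retries" means at least one restart was consumed during the run
    retried = remaining_restarts != max_restarts
    if result.is_failed():
        put_metric("run_failed_with_retries" if retried else "run_failed_no_retries", 1)
    else:
        put_metric("run_success_with_retries" if retried else "run_success_no_retries", 1)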
Example #6
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role

        # torch process context join() isn't really a join in the
        # traditional sense, it returns True if all the workers have
        # successfully finished, False if some/all are still running
        # and throws an Exception if some/all of them failed
        # passing timeout < 0 means check worker status and return immediately

        worker_pids = {w.id for w in worker_group.workers}
        pc_pids = set(self._process_context.pids())
        if worker_pids != pc_pids:
            log.error(
                f"[{role}] worker pids do not match process_context pids")
            return RunResult(state=WorkerState.UNKNOWN)

        proc_group_result = self._process_context.wait(timeout=1)
        if proc_group_result:
            if proc_group_result.is_failed():
                log.error(f"[{role}] Worker group failed")
                return RunResult(
                    state=WorkerState.FAILED,
                    return_values={},
                    failures={
                        w.global_rank: proc_group_result.failure
                        for w in worker_group.workers
                    },
                )
            else:
                # copy ret_val_queue into a map keyed by global rank
                workers_ret_vals = {}
                for local_rank, ret_val in proc_group_result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                    failures={},
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)
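
A caller treats the states returned by _monitor_workers differently: HEALTHY means keep polling, SUCCEEDED and FAILED end the run, and UNKNOWN signals a pid mismatch. The loop below is a hypothetical sketch of that consumption pattern, not the agent's actual run loop; it assumes only the states used in the examples above.

import time

# same enum used throughout these examples; path per recent PyTorch
from torch.distributed.elastic.agent.server.api import WorkerState


def poll_until_done(agent, worker_group, monitor_interval=1.0):
    while True:
        result = agent._monitor_workers(worker_group)
        if result.state in (WorkerState.SUCCEEDED,
                            WorkerState.FAILED,
                            WorkerState.UNKNOWN):
            return result
        # WorkerState.HEALTHY: workers still running, check again later
        time.sleep(monitor_interval)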
Example #7
 def test_shutdown_called(self, start_processes_mock):
     pcontext_mock = Mock()
     pcontext_mock.pids.return_value = {0: 0}
     start_processes_mock.return_value = pcontext_mock
     node_conf = Conf(entrypoint=_happy_function, local_world_size=1)
     spec = self.get_worker_spec(node_conf, max_restarts=0)
     agent = self.get_agent(spec)
     with patch.object(agent, "_monitor_workers") as monitor_mock:
         monitor_mock.return_value = RunResult(state=WorkerState.SUCCEEDED,
                                               return_values={0: 0})
         agent.run("worker")
     pcontext_mock.close.assert_called_once()
Example #8
    def test_launch_shutdown(self, agent_mock_cls):
        agent_mock = Mock()
        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
        agent_mock_cls.return_value = agent_mock
        rdzv_handler_mock = Mock()
        with patch(
                "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
        ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            elastic_launch(
                self.get_test_launch_config(1, 1, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"),
              f"--touch_file_dir={self.test_dir}")

            rdzv_handler_mock.shutdown.assert_called_once()
Example #9
 def test_launch_rdzv_shutdown(self, agent_mock_cls):
     nnodes = 1
     nproc_per_node = 4
     args = [
         f"--nnodes={nnodes}",
         f"--nproc_per_node={nproc_per_node}",
         "--monitor_interval=1",
         "--start_method=fork",
         path("bin/test_script.py"),
         f"--touch_file_dir={self.test_dir}",
     ]
     agent_mock = Mock()
     agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
     agent_mock_cls.return_value = agent_mock
     rdzv_handler_mock = Mock()
     with patch("torchelastic.rendezvous.registry.get_rendezvous_handler"
                ) as param_mock:
         param_mock.return_value = rdzv_handler_mock
         launch.main(args)
         rdzv_handler_mock.shutdown.assert_called_once()
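
Examples #8 and #9 assert the same invariant: the rendezvous handler is shut down once the (mocked) agent run returns. The sketch below shows that launch-then-cleanup pattern with illustrative names; it is not the launcher's actual code.

def launch_and_cleanup(agent, rdzv_handler, role="default"):
    try:
        result = agent.run(role)
        if result.is_failed():
            raise RuntimeError(f"worker group failed: {result.failures}")
        return result.return_values
    finally:
        # shut the rendezvous down whether the run succeeded or failed
        rdzv_handler.shutdown()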