def test_run_bipolar_function(self):
    """Run ``_bipolar_function`` under a fork-started agent and expect failure.

    The spec allows two restarts. ``agent.run()`` must raise, after which
    the worker group must be in ``WorkerState.FAILED`` and the agent's
    remaining-restart counter must be zero.
    """
    worker_spec = self._get_worker_spec(fn=_bipolar_function, max_restarts=2)
    elastic_agent = LocalElasticAgent(worker_spec, start_method="fork")

    # Any exception type is accepted here; only the fact of failure is checked.
    self.assertRaises(Exception, elastic_agent.run)

    final_state = elastic_agent.get_worker_group().state
    self.assertEqual(WorkerState.FAILED, final_state)
    self.assertEqual(0, elastic_agent._remaining_restarts)
    def test_run_sad_function(self):
        """Run ``_sad_function`` under a fork-started agent and expect group failure.

        The spec allows two restarts. ``agent.run()`` must raise
        ``WorkerGroupFailureException``; the exception must expose one
        ``Exception`` instance for each index in
        ``range(spec.local_world_size)``. Afterwards the worker group must be
        in ``WorkerState.FAILED`` with no restarts remaining.
        """
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(WorkerGroupFailureException) as cm:
            agent.run()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            # assertIsInstance reports the offending object and expected type
            # on failure, unlike assertTrue(isinstance(...)).
            self.assertIsInstance(
                excs[i], Exception, f"worker exception at index {i}"
            )

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)
# Example #3
0
    def _test_run_sad_function(self):
        """Run ``_sad_function`` with no restarts and inspect the error files.

        ``agent.run()`` is expected to return a result object whose
        ``failures`` mapping has one entry per local worker. Each entry's
        ``error_file`` must exist and contain the text
        ``RuntimeError: sad because i throw``. Afterwards the worker group
        must be in ``WorkerState.FAILED`` with no restarts remaining.

        NOTE(review): the leading underscore keeps unittest's default loader
        from collecting this method -- presumably disabled on purpose; confirm.
        """
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=0)
        agent = LocalElasticAgent(spec, start_method="fork")
        group_results = agent.run()
        failed_results = group_results.failures
        self.assertEqual(spec.local_world_size, len(failed_results))
        # all ranks will have the same result
        for result in failed_results.values():
            self.assertTrue(os.path.exists(result.error_file))
            with open(result.error_file, "r") as f:
                # Newlines are removed so the expected message matches even
                # if the file content is split across lines.
                data = f.read().replace("\n", "")
            # Assert after the file is closed; assertIn shows the file content
            # on failure instead of a bare "False is not true".
            self.assertIn("RuntimeError: sad because i throw", data)

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)
    def test_run_segv_function(self):
        """Run ``_fatal_signal_function`` with SIGSEGV and expect signaled workers.

        The spec passes ``(expected_error_index, expected_failure)`` as the
        worker arguments and allows two restarts. ``agent.run()`` must raise
        ``WorkerGroupFailureException``; each index in
        ``range(spec.local_world_size)`` must map to a
        ``WorkerSignaledException`` whose ``signal_name`` equals the name of
        ``signal.SIGSEGV``. Afterwards the worker group must be in
        ``WorkerState.FAILED`` with no restarts remaining.
        """
        expected_error_index = 0
        expected_failure = signal.SIGSEGV
        spec = self._get_worker_spec(
            fn=_fatal_signal_function,
            max_restarts=2,
            args=(expected_error_index, expected_failure),
        )
        try:
            agent = LocalElasticAgent(spec, start_method="spawn")
            with self.assertRaises(WorkerGroupFailureException) as cm:
                agent.run()
        finally:
            # Always shut the rendezvous handler down, even if the run or the
            # assertRaises check fails.
            spec.rdzv_handler.shutdown()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            # assertIsInstance reports the actual object on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(
                excs[i], WorkerSignaledException, f"worker exception at index {i}"
            )
            self.assertEqual(expected_failure.name, excs[i].signal_name)

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)