def test_run_failure(self):
    """Workers that raise must surface as ProcessRaisedException from wait().

    Every worker runs ``run_failure``; polling the process context is
    expected to raise before any truthy result is ever produced.
    """
    worker_count = 4
    proc_context = start_processes(
        [MpParameters(fn=run_failure, args=())] * worker_count,
        start_method="spawn",
    )
    with self.assertRaises(ProcessRaisedException):
        # wait() is polled until it returns something truthy (the original
        # loop implies it can return a falsy value while workers still run);
        # the expectation is that one of these calls raises instead.
        while not proc_context.wait():
            pass
def test_termination(self):
    """Terminating long-running workers must make the following wait() raise."""
    worker_count = 5
    worker_params = [MpParameters(fn=run_infinite, args=())] * worker_count
    proc_context = start_processes(worker_params, start_method="spawn")

    proc_context.terminate()

    # Processes should terminate with SIGTERM, so collecting their results
    # is expected to fail with some exception.
    with self.assertRaises(Exception):
        proc_context.wait()
def test_wait_busy_loop(self):
    """Poll wait() with a 1-second timeout until sleeping workers finish.

    Workers sleep for ``wait_time`` seconds, so the first short wait must
    time out (return None); subsequent polls continue until a result appears.
    """
    nprocs = 2
    wait_time = 10  # seconds
    sleeper = MpParameters(fn=run_with_wait, args=(wait_time,))
    proc_context = start_processes([sleeper] * nprocs, start_method="spawn")

    # Workers are still sleeping, so a 1-second wait must time out.
    self.assertIsNone(proc_context.wait(1))

    outcome = proc_context.wait(1)
    while outcome is None:
        outcome = proc_context.wait(1)
def test_run_success_no_return_func(self):
    """Workers whose function returns nothing report ``None`` per local rank.

    Runs ``run_dummy`` in ``nprocs`` spawned processes, collects the group
    result via ``self._get_result`` and checks there is one return value per
    process, each of them ``None``.
    """
    nprocs = 4
    params = [MpParameters(fn=run_dummy, args=())] * nprocs
    proc_context = start_processes(params, start_method="spawn")
    proc_group_result = self._get_result(proc_context)
    ret_vals = proc_group_result.return_values
    # Compare against nprocs rather than a literal so the check tracks it.
    self.assertEqual(nprocs, len(ret_vals))
    for ret_val in ret_vals.values():
        self.assertIsNone(ret_val)
def test_run_success_no_return_func(self):
    """Workers whose function returns nothing report ``None`` per local rank.

    Runs ``run_dummy`` in ``nprocs`` spawned processes and polls ``wait()``
    until it yields a truthy mapping of return values, then checks there is
    one entry per process and every entry is ``None``.
    """
    nprocs = 4
    params = [MpParameters(fn=run_dummy, args=())] * nprocs
    proc_context = start_processes(params, start_method="spawn")
    # wait() is polled until it yields a truthy result (workers may still be
    # running on earlier calls).
    ret_vals = proc_context.wait()
    while not ret_vals:
        ret_vals = proc_context.wait()
    # Compare against nprocs rather than a literal so the check tracks it.
    self.assertEqual(nprocs, len(ret_vals))
    for ret_val in ret_vals.values():
        self.assertIsNone(ret_val)
def test_run_success(self):
    """Each worker returns ``mult * local_rank``.

    Runs ``run_compute(mult)`` in ``nprocs`` spawned processes, collects the
    group result via ``self._get_result`` and checks every local rank's
    return value.
    """
    nprocs = 4
    mult = 2
    params = [MpParameters(fn=run_compute, args=(mult,))] * nprocs
    proc_context = start_processes(params, start_method="spawn")
    proc_group_result = self._get_result(proc_context)
    ret_vals = proc_group_result.return_values
    # Compare against nprocs rather than a literal so the check tracks it.
    self.assertEqual(nprocs, len(ret_vals))
    for local_rank, ret_val in ret_vals.items():
        self.assertEqual(mult * local_rank, ret_val)
def test_run_success(self):
    """Each worker returns ``mult * local_rank``.

    Runs ``run_compute(mult)`` in ``nprocs`` spawned processes and polls
    ``wait()`` until it yields a truthy mapping of local rank to return
    value, then checks each entry.
    """
    nprocs = 4
    mult = 2
    params = [MpParameters(fn=run_compute, args=(mult,))] * nprocs
    proc_context = start_processes(params, start_method="spawn")
    # wait() is polled until it yields a truthy result (workers may still be
    # running on earlier calls).
    ret_vals = proc_context.wait()
    while not ret_vals:
        ret_vals = proc_context.wait()
    # Compare against nprocs rather than a literal so the check tracks it.
    self.assertEqual(nprocs, len(ret_vals))
    for local_rank, ret_val in ret_vals.items():
        self.assertEqual(mult * local_rank, ret_val)
def test_run_failure(self):
    """A worker exception is recorded in the configured torchelastic error file.

    Points ``TORCHELASTIC_ERROR_FILE`` at ``<test_dir>/error.log``, configures
    the process error handler, runs ``run_failure`` in ``nprocs`` spawned
    processes, and checks the group's failure record references an existing
    error file mentioning ``RuntimeError: Test error``.

    The error handler is always cleaned up and the environment variable is
    restored to its previous value, even if an assertion fails, so state
    does not leak into other tests.
    """
    env_key = "TORCHELASTIC_ERROR_FILE"
    saved_error_file = os.environ.get(env_key)
    os.environ[env_key] = f"{self.test_dir}/error.log"
    try:
        _process_error_handler.configure()
        try:
            nprocs = 4
            params = [MpParameters(fn=run_failure, args=())] * nprocs
            proc_context = start_processes(params, start_method="spawn")
            proc_group_result = self._get_result(proc_context)
            failed_result = proc_group_result.failure
            # Fail with a clear message instead of an AttributeError below.
            self.assertIsNotNone(failed_result)
            self.assertTrue(os.path.exists(failed_result.error_file))
            with open(failed_result.error_file, "r") as f:
                data = f.read().replace("\n", "")
            self.assertIn("RuntimeError: Test error", data)
        finally:
            _process_error_handler.cleanup()
    finally:
        # Restore the caller's environment so later tests are unaffected.
        if saved_error_file is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved_error_file
def start_processes(
    params: List[MpParameters],
    start_method: str = "spawn",
):
    """
    Launch one worker process per ``MpParameters`` entry.

    ``params`` is materialised into a list and handed, together with
    ``start_method``, to ``mp_context.start_processes``; whatever process
    context that call returns is returned unchanged, so callers can wait on
    or terminate the group. Each process executes the same function.

    Note: All params are expected to have the same values; this is not
    checked here.

    Raises:
        ValueError: if ``params`` yields no entries.
    """
    materialised = list(params)
    if not materialised:
        raise ValueError(
            "Params cannot be empty. Provide at least single MpParameters object"
        )
    return mp_context.start_processes(materialised, start_method)
def test_run_huge_output(self):
    """Workers can return very large objects without deadlocking.

    Python's multiprocessing queues are backed by pipes, so when a single
    pickled object is larger than the pipe buffer, the writing process blocks
    until the reading process starts draining the pipe. This test has each
    worker return a large dict (``size`` entries; originally estimated at
    around 10 MB) to exercise that path.
    """
    nprocs = 4
    size = 200000
    params = [MpParameters(fn=fill_dict, args=(size,))] * nprocs
    proc_context = start_processes(params, start_method="spawn")
    proc_group_result = self._get_result(proc_context)
    ret_vals = proc_group_result.return_values
    # Compare against nprocs rather than a literal so the check tracks it.
    self.assertEqual(nprocs, len(ret_vals))
    for ret_val in ret_vals.values():
        self.assertEqual(size, len(ret_val))
def test_failure_signal(self):
    """A worker killed by SIGSEGV is reported with its signal and error file.

    Points ``TORCHELASTIC_ERROR_FILE`` at ``<test_dir>/error.log``, configures
    the process error handler, runs ``run_failure_signal`` in ``nprocs``
    spawned processes, and checks the failure record names ``SIGSEGV`` and
    references an existing error file that mentions ``string_at``.

    The error handler is always cleaned up and the environment variable is
    restored to its previous value, even if an assertion fails.
    """
    env_key = "TORCHELASTIC_ERROR_FILE"
    saved_error_file = os.environ.get(env_key)
    os.environ[env_key] = f"{self.test_dir}/error.log"
    try:
        _process_error_handler.configure()
        try:
            nprocs = 5
            params = [MpParameters(fn=run_failure_signal, args=())] * nprocs
            proc_context = start_processes(params, start_method="spawn")
            # Processes should terminate with SIGSEGV. Collect the group
            # result the same way the other error-file test does; a single
            # bare wait() may return before the workers have finished.
            proc_group_result = self._get_result(proc_context)
            failure = proc_group_result.failure
            # Fail with a clear message instead of an AttributeError below.
            self.assertIsNotNone(failure)
            self.assertTrue(os.path.exists(failure.error_file))
            self.assertEqual("SIGSEGV", failure.get_signal_name())
            with open(failure.error_file, "r") as f:
                data = f.read().replace("\n", "")
            self.assertIn("string_at", data)
        finally:
            _process_error_handler.cleanup()
    finally:
        # Restore the caller's environment so later tests are unaffected.
        if saved_error_file is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved_error_file