Exemple #1
0
 def test_run_failure(self):
     """Waiting on workers that run ``run_failure`` must raise ``ProcessRaisedException``."""
     worker_count = 4
     worker_params = [MpParameters(fn=run_failure, args=())] * worker_count
     ctx = start_processes(worker_params, start_method="spawn")
     with self.assertRaises(ProcessRaisedException):
         # Poll until wait() hands back something truthy; the expected
         # exception has to come out of one of these calls.
         while not ctx.wait():
             pass
Exemple #2
0
 def test_termination(self):
     """Terminate workers running ``run_infinite`` and expect ``wait()`` to raise."""
     worker_count = 5
     worker_params = [MpParameters(fn=run_infinite, args=())] * worker_count
     ctx = start_processes(worker_params, start_method="spawn")
     ctx.terminate()
     # Processes should terminate with SIGTERM, so waiting on them afterwards
     # is expected to raise.
     self.assertRaises(Exception, ctx.wait)
Exemple #3
0
 def test_wait_busy_loop(self):
     """Poll ``wait(1)`` repeatedly until workers running ``run_with_wait`` report a result."""
     worker_count = 2
     wait_seconds = 10
     worker_params = [MpParameters(fn=run_with_wait, args=(wait_seconds,))] * worker_count
     ctx = start_processes(worker_params, start_method="spawn")
     # The very first poll is required to come back with None.
     first_poll = ctx.wait(1)
     self.assertIsNone(first_poll)
     # Keep polling until wait(1) returns anything other than None.
     while True:
         if ctx.wait(1) is not None:
             break
Exemple #4
0
 def test_run_success_no_return_func(self):
     """Workers running ``run_dummy`` should each report ``None`` as their return value."""
     nprocs = 4
     params = [MpParameters(fn=run_dummy, args=())] * nprocs
     proc_context = start_processes(params, start_method="spawn")
     proc_group_result = self._get_result(proc_context)
     ret_vals = proc_group_result.return_values
     # Compare against nprocs rather than repeating the literal, so the
     # expected count cannot drift away from the number of started processes.
     self.assertEqual(nprocs, len(ret_vals))
     for ret_val in ret_vals.values():
         # Identity check: assertEqual(None, x) would also accept any object
         # whose __eq__ claims equality with None.
         self.assertIsNone(ret_val)
Exemple #5
0
 def test_run_success_no_return_func(self):
     """Workers running ``run_dummy`` should each report ``None`` as their return value."""
     nprocs = 4
     params = [MpParameters(fn=run_dummy, args=())] * nprocs
     proc_context = start_processes(params, start_method="spawn")
     # wait() may come back with a falsy value; keep calling it until it
     # yields the (truthy) mapping of return values.
     ret_vals = proc_context.wait()
     while not ret_vals:
         ret_vals = proc_context.wait()
     # Compare against nprocs rather than repeating the literal, so the
     # expected count cannot drift away from the number of started processes.
     self.assertEqual(nprocs, len(ret_vals))
     for ret_val in ret_vals.values():
         # Identity check: assertEqual(None, x) would also accept any object
         # whose __eq__ claims equality with None.
         self.assertIsNone(ret_val)
Exemple #6
0
 def test_run_success(self):
     """Each ``run_compute`` worker's return value should equal ``mult`` times its key in the results."""
     worker_count = 4
     factor = 2
     worker_params = [MpParameters(fn=run_compute, args=(factor,))] * worker_count
     ctx = start_processes(worker_params, start_method="spawn")
     results = self._get_result(ctx).return_values
     # One result is expected per started process.
     self.assertEqual(worker_count, len(results))
     for local_rank, value in results.items():
         expected = factor * local_rank
         self.assertEqual(expected, value)
Exemple #7
0
 def test_run_success(self):
     """Each ``run_compute`` worker's return value should equal ``mult`` times its key in the results."""
     worker_count = 4
     factor = 2
     worker_params = [MpParameters(fn=run_compute, args=(factor,))] * worker_count
     ctx = start_processes(worker_params, start_method="spawn")
     # Call wait() until it returns something truthy.
     while True:
         results = ctx.wait()
         if results:
             break
     # One result is expected per started process.
     self.assertEqual(worker_count, len(results))
     for local_rank, value in results.items():
         expected = factor * local_rank
         self.assertEqual(expected, value)
Exemple #8
0
 def test_run_failure(self):
     """Workers running ``run_failure`` should leave an error file naming the raised error.

     Points ``TORCHELASTIC_ERROR_FILE`` at ``<test_dir>/error.log``, configures
     the process error handler, runs the workers and checks that the error
     file referenced by the group failure exists and mentions the error.
     """
     os.environ["TORCHELASTIC_ERROR_FILE"] = f"{self.test_dir}/error.log"
     _process_error_handler.configure()
     # try/finally so the handler's cleanup() is not skipped when one of the
     # assertions below fails.
     try:
         nprocs = 4
         params = [MpParameters(fn=run_failure, args=())] * nprocs
         proc_context = start_processes(params, start_method="spawn")
         proc_group_result = self._get_result(proc_context)
         failed_result = proc_group_result.failure
         self.assertTrue(os.path.exists(failed_result.error_file))
         with open(failed_result.error_file, "r") as f:
             # Strip newlines so the expected text is matched even if the
             # file wraps it across lines.
             data = f.read().replace("\n", "")
         # assertIn reports both operands on failure, unlike assertTrue(a in b).
         self.assertIn("RuntimeError: Test error", data)
     finally:
         _process_error_handler.cleanup()
Exemple #9
0
def start_processes(
    params: List[MpParameters],
    start_method: str = "spawn",
):
    """
    Start processes via torch.multiprocessing.spawn, each executing the same
    function, and return the process context exposing methods over that set
    of processes.

    ``params`` is copied into a list and passed on, together with
    ``start_method``, to ``mp_context.start_processes``; the result of that
    call is returned as is.
    Note: All params must have the same values

    Raises:
        ValueError: if ``params`` contains no entries.
    """
    proc_params = [*params]
    if proc_params:
        return mp_context.start_processes(proc_params, start_method)
    raise ValueError(
        "Params cannot be empty. Provide at least single MpParameters object"
    )
Exemple #10
0
 def test_run_huge_output(self):
     """Workers running ``fill_dict`` should each return a container of ``size`` entries."""
     # Python's multiprocessing queues are backed by pipes. When a single
     # object is larger than the pipe size, the writing process blocks until
     # the reading process starts draining the pipe.
     # This test has the worker fn return a huge output, around ~10 MB.
     worker_count = 4
     size = 200000
     worker_params = [MpParameters(fn=fill_dict, args=(size,))] * worker_count
     ctx = start_processes(worker_params, start_method="spawn")
     results = self._get_result(ctx).return_values
     # One result is expected per started process.
     self.assertEqual(worker_count, len(results))
     for payload in results.values():
         self.assertEqual(size, len(payload))
Exemple #11
0
 def test_failure_signal(self):
     """Workers running ``run_failure_signal`` should report SIGSEGV and leave an error file.

     Points ``TORCHELASTIC_ERROR_FILE`` at ``<test_dir>/error.log``, configures
     the process error handler, waits on the workers and checks the signal
     name and error-file contents recorded on the group failure.
     """
     os.environ["TORCHELASTIC_ERROR_FILE"] = f"{self.test_dir}/error.log"
     _process_error_handler.configure()
     # try/finally so the handler's cleanup() is not skipped when one of the
     # assertions below fails.
     try:
         nprocs = 5
         params = [MpParameters(fn=run_failure_signal, args=())] * nprocs
         proc_context = start_processes(params, start_method="spawn")
         # Processes should terminate with SIGSEGV
         proc_group_result = proc_context.wait()
         failure = proc_group_result.failure
         self.assertTrue(os.path.exists(failure.error_file))
         self.assertEqual("SIGSEGV", failure.get_signal_name())
         with open(failure.error_file, "r") as f:
             # Strip newlines so the expected text is matched even if the
             # file wraps it across lines.
             data = f.read().replace("\n", "")
         # assertIn reports both operands on failure, unlike assertTrue(a in b).
         self.assertIn("string_at", data)
     finally:
         _process_error_handler.cleanup()