def test_torch_mp_example(self):
    """Run the happy and stuck worker functions through ``torch_mp.spawn``.

    A ``LocalTimerServer`` is started on a queue created from the "spawn"
    multiprocessing context. Eight ``_happy_function`` workers are spawned
    and expected to complete, then eight ``_stuck_function`` workers are
    spawned and the spawn call is expected to raise.

    The server is stopped in a ``finally`` clause so that a failing spawn
    or assertion does not leave it running.
    """
    # in practice set the max_interval to a larger value (e.g. 60 seconds)
    mp_queue = mp.get_context("spawn").Queue()
    server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
    server.start()
    try:
        world_size = 8

        # all processes should complete successfully
        # since start_process does NOT take context as parameter argument yet
        # this method WILL FAIL (hence the test is disabled)
        torch_mp.spawn(
            fn=_happy_function,
            args=(mp_queue,),
            nprocs=world_size,
            join=True,
        )

        with self.assertRaises(Exception):
            # torch.multiprocessing.spawn kills all sub-procs
            # if one of them gets killed
            torch_mp.spawn(
                fn=_stuck_function,
                args=(mp_queue,),
                nprocs=world_size,
                join=True,
            )
    finally:
        # Always release the server, even when the test body fails.
        server.stop()
def _run_example_with(self, start_method):
    """Run the mixed stuck/happy worker example under ``start_method``.

    Starts a ``LocalTimerServer`` on a queue from the requested
    multiprocessing context and launches eight worker processes:
    even ranks run ``_stuck_function`` and odd ranks run
    ``_happy_function``. After joining each process, even ranks are
    expected to have exit code ``-signal.SIGKILL`` and odd ranks ``0``.

    Args:
        start_method: Name passed to ``mp.get_context`` to select the
            multiprocessing start method.

    The server is stopped in a ``finally`` clause so that a failed start,
    join or assertion does not leave it running.
    """
    spawn_ctx = mp.get_context(start_method)
    mp_queue = spawn_ctx.Queue()
    server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
    server.start()
    try:
        world_size = 8
        processes = []
        for rank in range(world_size):
            # Even ranks get the stuck worker, odd ranks the happy one.
            target = _stuck_function if rank % 2 == 0 else _happy_function
            proc = spawn_ctx.Process(target=target, args=(rank, mp_queue))
            proc.start()
            processes.append(proc)

        for rank, proc in enumerate(processes):
            proc.join()
            if rank % 2 == 0:
                self.assertEqual(-signal.SIGKILL, proc.exitcode)
            else:
                self.assertEqual(0, proc.exitcode)
    finally:
        # Always release the server, even when the test body fails.
        server.stop()
def setUp(self):
    """Create a spawn-context queue and start a timer server on it.

    Stores the context, queue, interval and server on the test instance
    as ``ctx``, ``mp_queue``, ``max_interval`` and ``server``.
    """
    spawn_ctx = mp.get_context("spawn")
    self.ctx = spawn_ctx

    request_queue = spawn_ctx.Queue()
    self.mp_queue = request_queue

    interval = 0.01
    self.max_interval = interval

    timer_server = timer.LocalTimerServer(request_queue, interval)
    self.server = timer_server
    timer_server.start()
def setUp(self):
    """Create a queue via ``mp.Queue()`` and start a timer server on it.

    Stores the queue, interval and server on the test instance as
    ``mp_queue``, ``max_interval`` and ``server``.
    """
    request_queue = mp.Queue()
    self.mp_queue = request_queue

    interval = 0.01
    self.max_interval = interval

    timer_server = timer.LocalTimerServer(request_queue, interval)
    self.server = timer_server
    timer_server.start()