def testTargetArgs(self):
    """coordinator.Process forwards positional ``args`` to its ``target``."""
    counter = Value(ctypes.c_int, 3)
    coord = coordinator.Coordinator()
    worker = coordinator.Process(
        coord, target=_stop_at_0, args=(coord, counter))
    worker.start()
    # No explicit process list: presumably coordinator.Process registers
    # itself with `coord`, so join() waits for it.
    coord.join()
    # The shared counter is expected to have been driven down to 0.
    self.assertEqual(0, counter.value)
def testStopAPI(self):
    """should_stop()/wait_for_stop() flip from False to True on request_stop()."""
    coord = coordinator.Coordinator()

    def expect_stopped(expected):
        # Both stop queries must agree with `expected`.
        check = self.assertTrue if expected else self.assertFalse
        check(coord.should_stop())
        check(coord.wait_for_stop(0.01))

    expect_stopped(False)
    coord.request_stop()
    expect_stopped(True)
def testModelSharing(self):
    """Registered buffers of a share_memory()'d Algorithm are shared with a child.

    Plain tensor attributes are not shared. Without share_memory(), the
    parent's copy of the buffer is left untouched by the child.
    """

    class MyProcess(coordinator.Process):
        def __init__(self, coord, args=(), kwargs=None):
            # Use None instead of a shared mutable `{}` default.
            kwargs = {} if kwargs is None else kwargs
            super().__init__(coord, args=args, kwargs=kwargs)
            self.n = Value(ctypes.c_int, 3)

        def body(self, m=None):
            # Give the parent time to modify `m` before the child uses it.
            time.sleep(0.02)
            _stop_at_0(self._coord, self.n, m)

    from alf.algorithms.algorithm import Algorithm

    class MyAlgorithm(Algorithm):
        def __init__(self):
            super().__init__()
            self.register_buffer('_m', torch.tensor(0, dtype=torch.int32))
            self.x = torch.tensor(0, dtype=torch.int32)

        def decrement(self):
            self._m -= 1
            self.x -= 1

    m = MyAlgorithm()
    m.share_memory()
    coord = coordinator.Coordinator()
    p = MyProcess(coord, kwargs={"m": m})
    p.start()
    # Sleep just long enough for the subprocess to start, but before it
    # gets past the 0.02s sleep at the top of body().
    time.sleep(0.01)
    # A change in the parent process is reflected in the child process
    # via share_memory().
    m._m.fill_(-1)
    coord.join([p])
    # Registered buffers are shared across processes. Presumably
    # _stop_at_0 calls m.decrement() three times (n starts at 3), taking
    # -1 to -4; confirm against _stop_at_0.
    self.assertEqual(-4, m._m)
    # Simple tensors are not shared:
    self.assertEqual(0, m.x)

    m2 = MyAlgorithm()
    coord2 = coordinator.Coordinator()
    p2 = MyProcess(coord2, kwargs={"m": m2})
    p2.start()
    coord2.join([p2])
    # Without share_memory(), m2 in the parent process is not touched:
    self.assertEqual(0, m2._m)
    # Simple tensors are not shared:
    self.assertEqual(0, m2.x)
def testTargetKwargs(self):
    """coordinator.Process forwards keyword ``kwargs`` to its ``target``."""
    counter = Value(ctypes.c_int, 3)
    coord = coordinator.Coordinator()
    target_kwargs = dict(coord=coord, n=counter)
    worker = coordinator.Process(
        coord, target=_stop_at_0, kwargs=target_kwargs)
    worker.start()
    coord.join()
    # The shared counter is expected to have been driven down to 0.
    self.assertEqual(0, counter.value)
def testJoin(self):
    """join(processes) returns only after every listed process has exited."""
    coord = coordinator.Coordinator()
    processes = [
        Process(target=sleep_a_bit, args=(secs, ))
        for secs in (0.02, 0.03, 0.02)
    ]
    for proc in processes:
        proc.start()
    coord.join(processes)
    self.assertFalse(any(proc.is_alive() for proc in processes))
def testInheritedTarget(self):
    """A Process subclass overriding body() receives the constructor kwargs.

    No ``target`` is passed; the test expects the overridden body() to run
    with ``n`` and drive the shared counter to 0.
    """

    class MyProcess(coordinator.Process):
        def __init__(self, coord, args=(), kwargs=None):
            # Use None instead of a shared mutable `{}` default.
            kwargs = {} if kwargs is None else kwargs
            super().__init__(coord, args=args, kwargs=kwargs)

        def body(self, n=None):
            _stop_at_0(self._coord, n)

    n = Value(ctypes.c_int, 3)
    coord = coordinator.Coordinator()
    p = MyProcess(coord, kwargs={"n": n})
    p.start()
    coord.join([p])
    self.assertEqual(0, n.value)
def testJoinAllRegistered(self):
    """join() with no arguments waits for every registered process."""
    coord = coordinator.Coordinator()
    processes = [
        Process(target=sleep_a_bit, args=(secs, ))
        for secs in (0.02, 0.03, 0.02)
    ]
    for proc in processes:
        proc.start()
    # Register only after all processes have been started.
    for proc in processes:
        coord.register_process(proc)
    coord.join()
    self.assertFalse(any(proc.is_alive() for proc in processes))
def testStopAsync(self):
    """A stop requested from a child process becomes visible in the parent."""
    coord = coordinator.Coordinator()
    # Nothing has requested a stop yet.
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.1))
    go_ev = Event()
    stopped_ev = Event()
    stopper = Process(target=stop_on_event,
                      args=(coord, go_ev, stopped_ev))
    stopper.start()
    # The child presumably waits on go_ev, so still no stop request here.
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.01))
    go_ev.set()
    stopped_ev.wait()
    # Once the child signals stopped_ev, the stop is expected to be visible.
    self.assertTrue(coord.wait_for_stop(0.05))
    self.assertTrue(coord.should_stop())
def testJoinSomeRegistered(self):
    """join() covers registered processes plus those passed explicitly."""
    coord = coordinator.Coordinator()
    processes = [
        Process(target=sleep_a_bit, args=(secs, ))
        for secs in (0.02, 0.03, 0.02)
    ]
    for proc in processes:
        proc.start()
    first, middle, last = processes
    coord.register_process(first)
    coord.register_process(last)
    # `middle` was never registered, so it must be handed to join()
    # directly; the registered ones are expected to be joined as well.
    coord.join([middle])
    self.assertFalse(any(proc.is_alive() for proc in processes))
def testJoinRaiseReportExceptionUsingHandler(self):
    """join() re-raises the first exception reported via the context handler."""
    coord = coordinator.Coordinator()
    first_ev, second_ev = Event(), Event()
    processes = [
        Process(target=raise_on_event_using_context_handler,
                args=(coord, first_ev, second_ev, RuntimeError("First"))),
        Process(target=raise_on_event_using_context_handler,
                args=(coord, second_ev, None, RuntimeError("Too late"))),
    ]
    for proc in processes:
        proc.start()
    # Presumably the first child raises "First" once first_ev is set and
    # then sets second_ev, which makes the second child raise "Too late".
    first_ev.set()
    # Only the earlier exception, "First", is expected from join().
    with self.assertRaisesRegex(RuntimeError, "First"):
        coord.join(processes)
def testJoinWithoutGraceExpires(self):
    """join(..., ignore_live_processes=True) tolerates a still-running process."""
    coord = coordinator.Coordinator()
    stop_trigger = Event()
    stop_done = Event()
    processes = [
        Process(target=stop_on_event,
                args=(coord, stop_trigger, stop_done)),
        # Sleeps far longer than the 1s grace period used below.
        Process(target=sleep_a_bit, args=(10.0, )),
    ]
    for proc in processes:
        # Daemonic, so the leftover sleeper is terminated when the test
        # process exits.
        proc.daemon = True
        proc.start()
    stop_trigger.set()
    stop_done.wait()
    # Must not raise even though the sleeper outlives the grace period.
    coord.join(processes, stop_grace_period_secs=1.,
               ignore_live_processes=True)
def TestWithGracePeriod(stop_grace_period):
    """join() raises when a process outlives ``stop_grace_period`` seconds."""
    # NOTE(review): `self` is not a parameter here; presumably this helper
    # is nested inside a test method and takes `self` from that scope.
    coord = coordinator.Coordinator()
    stop_trigger = Event()
    stop_done = Event()
    processes = [
        Process(target=stop_on_event,
                args=(coord, stop_trigger, stop_done)),
        # Sleeps far longer than any grace period this helper is given.
        Process(target=sleep_a_bit, args=(10.0, )),
    ]
    for proc in processes:
        proc.daemon = True
        proc.start()
    stop_trigger.set()
    stop_done.wait()
    with self.assertRaisesRegex(RuntimeError, "processes still running"):
        coord.join(processes, stop_grace_period_secs=stop_grace_period)
def testRequestStopRaisesIfJoined(self):
    """request_stop(exc) after join() raises immediately until clear_stop()."""
    coord = coordinator.Coordinator()
    # Join right away so that any later stop request arrives "too late".
    coord.join([])

    reports = []

    def raise_and_report(message):
        # Raise `message` and hand the caught exception to request_stop().
        try:
            raise RuntimeError(message)
        except RuntimeError as err:
            reports.append(message)
            coord.request_stop(err)

    with self.assertRaisesRegex(RuntimeError, "Too late"):
        raise_and_report("Too late")
    self.assertTrue(reports)

    # After clear_stop(), the exception is stored and surfaced by join().
    coord.clear_stop()
    raise_and_report("After clear")
    with self.assertRaisesRegex(RuntimeError, "After clear"):
        coord.join([])
def testRequestStopRaisesIfJoined_ExcInfo(self):
    """Same as testRequestStopRaisesIfJoined, but reports via sys.exc_info()."""
    coord = coordinator.Coordinator()
    # Join the coordinator right away.
    coord.join([])
    reported = False
    with self.assertRaisesRegex(RuntimeError, "Too late"):
        try:
            raise RuntimeError("Too late")
        except RuntimeError:
            reported = True
            coord.request_stop(sys.exc_info())
    self.assertTrue(reported)
    # After clear_stop(), the reported exception is stored and surfaced by
    # the next join() as usual.
    coord.clear_stop()
    try:
        raise RuntimeError("After clear")
    except RuntimeError:
        coord.request_stop(sys.exc_info())
    with self.assertRaisesRegex(RuntimeError, "After clear"):
        coord.join([])
def testJoinRaiseReportExcInfo(self):
    """join() re-raises the first exception a child reported (exc_info path).

    The trailing ``False`` passed to raise_on_event presumably selects
    reporting through ``sys.exc_info()`` rather than the exception object,
    per the test name; confirm against raise_on_event.
    """
    coord = coordinator.Coordinator()
    ev_1 = Event()
    ev_2 = Event()
    processes = [
        Process(target=raise_on_event,
                args=(coord, ev_1, ev_2, RuntimeError("First"), False)),
        Process(target=raise_on_event,
                args=(coord, ev_2, None, RuntimeError("Too late"), False))
    ]
    for t in processes:
        t.start()
    ev_1.set()
    # join() is expected to re-raise the first exception reported by a
    # child ("First"), not the later "Too late".
    #
    # If child tracebacks ever need to be printed in the parent, something
    # like this could be used:
    # https://stackoverflow.com/questions/19924104/python-multiprocessing-handling-child-errors-in-parent
    with self.assertRaisesRegex(RuntimeError, "First"):
        coord.join(processes)
def testClearStopClearsExceptionToo(self):
    """clear_stop() also discards the exception stored by the previous stop."""
    coord = coordinator.Coordinator()
    trigger = Event()
    for round_idx, message in enumerate(("First", "Second")):
        if round_idx:
            # Without clearing, join() would keep re-raising "First".
            coord.clear_stop()
        workers = [
            Process(target=raise_on_event,
                    args=(coord, trigger, None, RuntimeError(message), True)),
        ]
        for worker in workers:
            worker.start()
        # In round two `trigger` is already set, so set() is a no-op and
        # the worker presumably raises as soon as it starts.
        with self.assertRaisesRegex(RuntimeError, message):
            trigger.set()
            coord.join(workers)