def __init__(self,
             *,
             every_steps: Optional[int] = None,
             every_secs: Optional[float] = None,
             on_steps: Optional[Iterable[int]] = None,
             callback_fn: Callable,
             execute_async: bool = False,
             pass_step_and_time: bool = True):
    """Initializes a new periodic Callback action.

    Args:
      every_steps: See `PeriodicAction.__init__()`.
      every_secs: See `PeriodicAction.__init__()`.
      on_steps: See `PeriodicAction.__init__()`.
      callback_fn: A callback function. When `pass_step_and_time` is True it
        must accept `step` and `t` as arguments; arguments are passed by
        keyword. Any callable is accepted, including ones without a
        `__name__` attribute (e.g. `functools.partial` objects).
      execute_async: if True wraps the callback into an async call.
      pass_step_and_time: if True the step and t are passed to the callback.
    """
    super().__init__(
        every_steps=every_steps, every_secs=every_secs, on_steps=on_steps)
    # Bounded to a single entry: only the most recent callback result is kept.
    self._cb_results = collections.deque(maxlen=1)
    self.pass_step_and_time = pass_step_and_time
    if execute_async:
        logging.info("Callback will be executed asynchronously. "
                     "Errors are raised when they become available.")
        # Not every callable has `__name__` (functools.partial objects and
        # instances defining __call__ do not); fall back to the type name so
        # such callbacks can still be wrapped into an async call.
        thread_name_prefix = getattr(
            callback_fn, "__name__", type(callback_fn).__name__)
        self._cb_fn = asynclib.Pool(thread_name_prefix=thread_name_prefix)(
            callback_fn)
    else:
        self._cb_fn = callback_fn
def test_queue_length(self, executor_mock):
    """Checks that queue_length counts submitted-but-not-yet-run calls."""
    pending = []

    def fake_submit(fn, *args, **kwargs):
        # Defer execution: remember a thunk instead of running the call now.
        pending.append(lambda: fn(*args, **kwargs))

    def run_oldest():
        pending.pop(0)()

    fake_executor = mock.Mock()
    fake_executor.submit = fake_submit
    executor_mock.return_value = fake_executor

    pool = asynclib.Pool()

    @pool
    def noop():
        ...

    self.assertEqual(pool.queue_length, 0)
    # Each submission grows the queue by one...
    for expected_length in (1, 2):
        noop()
        self.assertEqual(pool.queue_length, expected_length)
    # ...and each executed thunk shrinks it by one.
    for expected_length in (1, 0):
        run_oldest()
        self.assertEqual(pool.queue_length, expected_length)
def __init__(self,
             writer: interface.MetricWriter,
             *,
             num_workers: Optional[int] = 1):
    """Wraps `writer` so that its calls run on a background thread pool.

    Args:
      writer: The underlying metric writer.
      num_workers: Maximum number of worker threads for the pool. The default
        of a single worker keeps calls executed in submission order.
    """
    super().__init__()
    self._writer = writer
    self._num_workers = num_workers
    # One worker by default: calls to the function run in order, but off the
    # caller's thread.
    self._pool = asynclib.Pool(
        max_workers=num_workers,
        thread_name_prefix="AsyncWriter",
    )
def test_async_execution(self):
    """A pool-wrapped call runs in the background and returns a future."""
    pool = asynclib.Pool()
    counter = 0

    @pool
    def fn(counter_increment, return_value):
        nonlocal counter
        counter += counter_increment
        return return_value

    future = fn(1, return_value=2)
    # The call executes on a background thread; block on the future before
    # inspecting its side effects, otherwise the counter check races the
    # worker and the test is flaky.
    self.assertEqual(future.result(), 2)
    self.assertEqual(counter, 1)
    # Wait for the pool to drain so no worker thread outlives the test.
    pool.join()
def test_flush(self, executor_mock):
    """Checks that join() drains queued calls and the pool stays usable."""
    fake_executor = mock.Mock()
    # None while no fake executor is alive; a list of deferred thunks otherwise.
    queued = None

    def run_next():
        queued.pop(0)()

    def fake_submit(fn, *args, **kwargs):
        queued.append(lambda: fn(*args, **kwargs))

    def fake_executor_factory(max_workers, thread_name_prefix):
        nonlocal queued
        del max_workers
        del thread_name_prefix
        queued = []
        return fake_executor

    def fake_shutdown(wait=False):
        nonlocal queued
        if wait:
            # Emulate a blocking shutdown by running everything still queued.
            while queued:
                run_next()
        queued = None

    fake_executor.submit = fake_submit
    fake_executor.shutdown.side_effect = fake_shutdown
    executor_mock.side_effect = fake_executor_factory

    pool = asynclib.Pool()

    @pool
    def noop():
        ...

    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
    noop()
    pool.join()
    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
def test_reraise(self):
    """Background errors surface on join() or on the next wrapped call."""
    pool = asynclib.Pool()

    @pool
    def error():
        raise ValueError("test")

    # The failing call runs on a background thread. Wait for it to finish
    # before checking has_errors; otherwise the assertion races the worker.
    # Future.exception() blocks until completion without raising.
    error().exception()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(ValueError, "test"):
        pool.join()
    self.assertFalse(pool.has_errors)

    @pool
    def noop():
        ...

    # Same synchronization as above before asserting on the recorded error.
    error().exception()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(ValueError, "test"):
        noop()
    self.assertFalse(pool.has_errors)
    pool.join()