示例#1
0
    def __init__(self,
                 *,
                 every_steps: Optional[int] = None,
                 every_secs: Optional[float] = None,
                 on_steps: Optional[Iterable[int]] = None,
                 callback_fn: Callable,
                 execute_async: bool = False,
                 pass_step_and_time: bool = True):
        """Initializes a new periodic Callback action.

        Args:
          every_steps: See `PeriodicAction.__init__()`.
          every_secs: See `PeriodicAction.__init__()`.
          on_steps: See `PeriodicAction.__init__()`.
          callback_fn: A callback function. When `pass_step_and_time` is True it
            must accept `step` and `t` as arguments; they are passed by keyword.
            Any callable works, including `functools.partial` objects and
            callable instances that have no `__name__`.
          execute_async: If True, wraps the callback into an async call on an
            `asynclib.Pool`; errors are raised when they become available.
          pass_step_and_time: If True, `step` and `t` are passed to the callback.
        """
        super().__init__(every_steps=every_steps,
                         every_secs=every_secs,
                         on_steps=on_steps)
        # Keeps at most one entry (maxlen=1); presumably the latest callback
        # result -- confirm against the code that appends to it.
        self._cb_results = collections.deque(maxlen=1)
        self.pass_step_and_time = pass_step_and_time
        if execute_async:
            logging.info("Callback will be executed asynchronously. "
                         "Errors are raised when they become available.")
            # `functools.partial` objects and most callable instances lack
            # `__name__`; fall back to the type name so that enabling async
            # execution cannot fail with AttributeError.
            pool_name = getattr(callback_fn, "__name__",
                                type(callback_fn).__name__)
            self._cb_fn = asynclib.Pool(pool_name)(callback_fn)
        else:
            self._cb_fn = callback_fn
示例#2
0
  def test_queue_length(self, executor_mock):
    # Fake executor: `submit` only records the call, and the recorded tasks run
    # one at a time when the test calls `run_next()`. This keeps the observed
    # queue length deterministic.
    pending = []

    def run_next():
      task = pending.pop(0)
      task()

    def fake_submit(fn, *args, **kwargs):
      pending.append(lambda: fn(*args, **kwargs))

    fake_executor = mock.Mock()
    fake_executor.submit = fake_submit
    executor_mock.return_value = fake_executor

    pool = asynclib.Pool()

    @pool
    def noop():
      ...

    self.assertEqual(pool.queue_length, 0)
    # Every scheduled call grows the queue by one...
    for expected_length in (1, 2):
      noop()
      self.assertEqual(pool.queue_length, expected_length)
    # ...and every finished call shrinks it by one.
    for expected_length in (1, 0):
      run_next()
      self.assertEqual(pool.queue_length, expected_length)
示例#3
0
 def __init__(self,
              writer: interface.MetricWriter,
              *,
              num_workers: Optional[int] = 1):
   """Sets up a writer whose calls run on a background thread pool.

   Args:
     writer: The underlying metric writer that receives the calls.
     num_workers: Number of worker threads in the pool. The default of one
       worker keeps calls in submission order; larger values presumably give
       up that ordering guarantee.
   """
   super().__init__()
   self._writer = writer
   self._num_workers = num_workers
   # With a single worker (the default) the calls are serialized. They still
   # run in order even though they run off the caller's thread.
   self._pool = asynclib.Pool(
       max_workers=num_workers,
       thread_name_prefix="AsyncWriter",
   )
示例#4
0
  def test_async_execution(self):
    # A function decorated with the pool returns a future, and its body runs
    # on a pool thread.
    pool = asynclib.Pool()
    counter = 0

    @pool
    def fn(counter_increment, return_value):
      nonlocal counter
      counter += counter_increment
      return return_value

    future = fn(1, return_value=2)
    # `fn` runs on another thread, so `counter` may still be 0 right after the
    # call returns. Block on the future first: once it holds a result, the
    # whole function body (including the increment) has finished.
    self.assertEqual(future.result(), 2)
    self.assertEqual(counter, 1)
示例#5
0
  def test_flush(self, executor_mock):
    # Fake executor that queues submitted calls and runs them only when they
    # are drained explicitly. Every construction of the executor (via
    # `create_pool`) starts a fresh, empty queue, and `shutdown` discards it.
    pool_mock = mock.Mock()
    pool_mock._in_flight = None

    def execute_one():
      pool_mock._in_flight.pop(0)()

    def submit(fn, *args, **kwargs):
      pool_mock._in_flight.append(lambda: fn(*args, **kwargs))

    def create_pool(max_workers, thread_name_prefix):
      del max_workers
      del thread_name_prefix
      pool_mock._in_flight = []
      return pool_mock

    def shutdown(wait=True):
      # The default mirrors `concurrent.futures.Executor.shutdown(wait=True)`.
      # A bare `shutdown()` therefore drains the pending calls, as the real
      # executor does, instead of silently dropping them.
      if wait:
        while pool_mock._in_flight:
          execute_one()
      pool_mock._in_flight = None

    pool_mock.submit = submit
    executor_mock.side_effect = create_pool
    pool_mock.shutdown.side_effect = shutdown

    pool = asynclib.Pool()

    @pool
    def noop():
      ...

    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
    noop()
    # `join()` must drain both queued calls (presumably via the executor's
    # `shutdown(wait=True)`)...
    pool.join()
    self.assertEqual(pool.queue_length, 0)
    # ...and the pool keeps accepting work afterwards on a fresh executor.
    noop()
    self.assertEqual(pool.queue_length, 1)
示例#6
0
  def test_reraise(self):
    # Errors raised on pool threads are re-raised by `join()` or by the next
    # call scheduled on the pool, and are cleared once re-raised.
    pool = asynclib.Pool()

    @pool
    def error():
      raise ValueError("test")

    # `error` runs on a pool thread. Wait for its future before checking
    # `has_errors`; otherwise the check can run before the failure is
    # recorded. `Future.exception()` blocks until the call is done without
    # raising. NOTE(review): this assumes the pool records the error before the
    # future completes (not in a done-callback); confirm against asynclib.Pool.
    error().exception()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(ValueError, "test"):
      pool.join()
    self.assertFalse(pool.has_errors)

    @pool
    def noop():
      ...

    error().exception()
    self.assertTrue(pool.has_errors)
    # A pending error surfaces on the next call scheduled through the pool.
    with self.assertRaisesRegex(ValueError, "test"):
      noop()
    self.assertFalse(pool.has_errors)

    pool.join()