Example #1
    def test_multiple_computations_with_same_executor(self):
        @computations.tf_computation(tf.int32)
        def add_one(x):
            return tf.add(x, 1)

        ex = thread_delegating_executor.ThreadDelegatingExecutor(
            eager_tf_executor.EagerTFExecutor())

        async def compute():
            fn = await ex.create_value(add_one)
            arg = await ex.create_value(10, tf.int32)
            call = await ex.create_call(fn, arg)
            tup = await ex.create_tuple(collections.OrderedDict([('a', call)]))
            return await ex.create_selection(tup, name='a')

        result = asyncio.get_event_loop().run_until_complete(compute())
        self.assertIsInstance(result, eager_tf_executor.EagerValue)
        self.assertEqual(result.internal_representation.numpy(), 11)

        # After this call, the ThreadDelegatingExecutor has been closed, and needs
        # to be re-initialized.
        ex.close()

        result = asyncio.get_event_loop().run_until_complete(compute())
        self.assertIsInstance(result, eager_tf_executor.EagerValue)
        self.assertEqual(result.internal_representation.numpy(), 11)
Example #2
    def test_close_then_use_executor(self):
        ex = thread_delegating_executor.ThreadDelegatingExecutor(
            eager_tf_executor.EagerTFExecutor())
        ex.close()
        result = self.use_executor(ex)
        self.assertIsInstance(result, eager_tf_executor.EagerValue)
        self.assertEqual(result.internal_representation.numpy(), 11)
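The use_executor helper is defined elsewhere in the test class; a minimal sketch of what it plausibly does, reconstructed from the assertions above (hypothetical, assuming an add_one tf_computation as in Example #1):

    def use_executor(self, ex):
        # Hypothetical helper: runs add_one(10) on the given executor and
        # returns the resulting executor value (an EagerValue wrapping 11).
        @computations.tf_computation(tf.int32)
        def add_one(x):
            return tf.add(x, 1)

        async def compute():
            fn = await ex.create_value(add_one)
            arg = await ex.create_value(10, tf.int32)
            return await ex.create_call(fn, arg)

        return asyncio.get_event_loop().run_until_complete(compute())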
Example #3
def _wrap_executor_in_threading_stack(ex: executor_base.Executor,
                                      use_caching: Optional[bool] = True):
    threaded_ex = thread_delegating_executor.ThreadDelegatingExecutor(ex)
    if use_caching:
        threaded_ex = caching_executor.CachingExecutor(threaded_ex)
    rre_wrapped_ex = reference_resolving_executor.ReferenceResolvingExecutor(
        threaded_ex)
    return rre_wrapped_ex
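An illustrative call site for this wrapper (an assumption, not shown in the source):

# Illustrative usage (assumption): with the default use_caching=True this
# yields ReferenceResolvingExecutor -> CachingExecutor ->
# ThreadDelegatingExecutor -> EagerTFExecutor.
stack = _wrap_executor_in_threading_stack(eager_tf_executor.EagerTFExecutor())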
Example #4
    def test_close_then_use_executor_with_cache(self):
        # Integration test verifying that use after close is compatible with
        # the combined concurrent and caching executors. This was broken in
        # the past due to interactions between closing, caching, and the
        # concurrent executor. See b/148288711 for context.
        ex = thread_delegating_executor.ThreadDelegatingExecutor(
            caching_executor.CachingExecutor(
                eager_tf_executor.EagerTFExecutor()))
        self.use_executor(ex)
        ex.close()
        self.use_executor(ex)
Example #5
    def make_output():
        test_ex = FakeExecutor()
        executors = [
            thread_delegating_executor.ThreadDelegatingExecutor(test_ex)
            for _ in range(10)
        ]
        loop = asyncio.get_event_loop()
        vals = [ex.create_value(idx) for idx, ex in enumerate(executors)]
        results = loop.run_until_complete(asyncio.gather(*vals))
        self.assertCountEqual(list(results), list(range(10)))
        del executors
        return test_ex.output
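FakeExecutor is defined elsewhere in the test module; a minimal sketch covering only the pieces this snippet relies on (a hypothetical reconstruction, and depending on the TFF version the abstract base may require create_struct instead of create_tuple):

class FakeExecutor(executor_base.Executor):
    # Hypothetical stand-in: records every value handed to create_value so the
    # test can inspect what reached it through the thread-delegating wrappers.

    def __init__(self):
        self._values = []

    @property
    def output(self):
        return ''.join(str(v) for v in self._values)

    async def create_value(self, value, type_spec=None):
        self._values.append(value)
        return value

    async def create_call(self, comp, arg=None):
        raise NotImplementedError()

    async def create_tuple(self, elements):
        raise NotImplementedError()

    async def create_selection(self, source, index=None, name=None):
        raise NotImplementedError()

    def close(self):
        pass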
Example #6
    def test_end_to_end(self):
        @computations.tf_computation(tf.int32)
        def add_one(x):
            return tf.add(x, 1)

        executor = thread_delegating_executor.ThreadDelegatingExecutor(
            eager_tf_executor.EagerTFExecutor())

        result = _invoke(executor, add_one, 7)
        self.assertEqual(result, 8)

        # After this invocation, the ThreadDelegatingExecutor has been closed,
        # and needs to be re-initialized.

        result = _invoke(executor, add_one, 8)
        self.assertEqual(result, 9)
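The _invoke helper is not shown here; a plausible sketch, using only executor methods that appear in the other examples and closing the executor as the comment above implies (hypothetical, not the original helper):

def _invoke(ex, comp, arg):
    # Hypothetical helper: pushes the computation and its argument into the
    # executor, calls it, materializes the result, and then closes the
    # executor (hence the "needs to be re-initialized" comment above).
    async def compute():
        fn = await ex.create_value(comp)
        arg_val = await ex.create_value(arg, tf.int32)
        call = await ex.create_call(fn, arg_val)
        return await call.compute()

    result = asyncio.get_event_loop().run_until_complete(compute())
    ex.close()
    return result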
Example #7
def _wrap_executor_in_threading_stack(ex: executor_base.Executor,
                                      support_sequence_ops: bool = False,
                                      can_resolve_references=True):
    threaded_ex = thread_delegating_executor.ThreadDelegatingExecutor(ex)
    if support_sequence_ops:
        if not can_resolve_references:
            raise ValueError(
                'Support for sequence ops requires ability to resolve references.'
            )
        threaded_ex = sequence_executor.SequenceExecutor(
            reference_resolving_executor.ReferenceResolvingExecutor(
                threaded_ex))
    if can_resolve_references:
        threaded_ex = reference_resolving_executor.ReferenceResolvingExecutor(
            threaded_ex)
    return threaded_ex
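An illustrative call (an assumption, not from the source) that exercises the sequence-op branch of this wrapper:

# Illustrative usage (assumption): with support_sequence_ops=True the stack
# becomes ReferenceResolvingExecutor -> SequenceExecutor ->
# ReferenceResolvingExecutor -> ThreadDelegatingExecutor -> EagerTFExecutor.
stack = _wrap_executor_in_threading_stack(
    eager_tf_executor.EagerTFExecutor(), support_sequence_ops=True)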
Example #8
    def make_output():
        test_ex = FakeExecutor()
        executors = [
            thread_delegating_executor.ThreadDelegatingExecutor(test_ex)
            for _ in range(10)
        ]
        vals = [ex.create_value(idx) for idx, ex in enumerate(executors)]

        async def gather_coro(vals):
            return await asyncio.gather(*vals)

        results = asyncio.run(gather_coro(vals))
        results = [
            thread_value.internal_representation
            for thread_value in results
        ]
        self.assertCountEqual(results, list(range(10)))
        del executors
        return test_ex.output
Example #9
def _complete_stack(ex):
    return reference_resolving_executor.ReferenceResolvingExecutor(
        caching_executor.CachingExecutor(
            thread_delegating_executor.ThreadDelegatingExecutor(ex)))
Example #10
def _create_bottom_stack():
  return reference_resolving_executor.ReferenceResolvingExecutor(
      caching_executor.CachingExecutor(
          thread_delegating_executor.ThreadDelegatingExecutor(
              eager_tf_executor.EagerTFExecutor())))
Example #11
def _threaded_eager_executor() -> executor_base.Executor:
    return thread_delegating_executor.ThreadDelegatingExecutor(
        eager_tf_executor.EagerTFExecutor())
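An illustrative comparison of the two factories above (an assumption, not from the source):

# Illustrative usage (assumption): _create_bottom_stack adds reference
# resolution and caching on top of the thread-delegated eager executor, while
# _threaded_eager_executor returns only the bare thread-delegating wrapper.
full_stack = _create_bottom_stack()
bare_threaded = _threaded_eager_executor()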