def test_multiple_computations_with_same_executor(self):
  """Runs one computation twice on the same executor, closing it in between.

  Verifies that a `ThreadDelegatingExecutor` yields the same result after an
  explicit `close()` as it did before.
  """

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return tf.add(x, 1)

  ex = thread_delegating_executor.ThreadDelegatingExecutor(
      eager_tf_executor.EagerTFExecutor())

  async def compute():
    # Builds the struct `<a=add_one(10)>` and then selects its `a` element.
    return await ex.create_selection(
        await ex.create_tuple(
            collections.OrderedDict([
                ('a', await ex.create_call(
                    await ex.create_value(add_one),
                    await ex.create_value(10, tf.int32)))
            ])),
        name='a')

  # `asyncio.run` creates and closes its own event loop. Calling
  # `asyncio.get_event_loop()` with no running loop warns on Python 3.12+
  # and raises RuntimeError on 3.14+ when no loop is set.
  result = asyncio.run(compute())
  self.assertIsInstance(result, eager_tf_executor.EagerValue)
  self.assertEqual(result.internal_representation.numpy(), 11)

  # Explicitly close the executor. The next run checks that the executor is
  # usable again afterwards and still produces the same value.
  ex.close()
  result = asyncio.run(compute())
  self.assertIsInstance(result, eager_tf_executor.EagerValue)
  self.assertEqual(result.internal_representation.numpy(), 11)
def test_close_then_use_executor(self):
  """Checks that an executor closed before any use can still be used."""
  executor = thread_delegating_executor.ThreadDelegatingExecutor(
      eager_tf_executor.EagerTFExecutor())
  # Close right away, before any computation has been submitted.
  executor.close()
  value = self.use_executor(executor)
  self.assertIsInstance(value, eager_tf_executor.EagerValue)
  self.assertEqual(value.internal_representation.numpy(), 11)
def _wrap_executor_in_threading_stack(ex: executor_base.Executor,
                                      use_caching: bool = True):
  """Wraps `ex` in a thread-delegating stack topped by reference resolution.

  From outermost to innermost, the resulting stack is:
  `ReferenceResolvingExecutor` -> [`CachingExecutor`] ->
  `ThreadDelegatingExecutor` -> `ex`.

  Args:
    ex: The executor to wrap.
    use_caching: Whether to insert a `CachingExecutor` between the reference
      resolving layer and the thread-delegating layer. The value is only
      tested for truthiness, so a falsy value such as `None` disables caching.

  Returns:
    The outermost `ReferenceResolvingExecutor` of the stack.
  """
  threaded_ex = thread_delegating_executor.ThreadDelegatingExecutor(ex)
  if use_caching:
    threaded_ex = caching_executor.CachingExecutor(threaded_ex)
  return reference_resolving_executor.ReferenceResolvingExecutor(threaded_ex)
def test_close_then_use_executor_with_cache(self):
  """Checks use-after-close on a thread-delegating executor over a cache.

  This integration test verifies that an executor can be used after being
  closed when thread delegation and caching executors are combined. This
  previously broke because of interactions between closing, caching, and the
  concurrent executor; see b/148288711 for context.
  """
  cached = caching_executor.CachingExecutor(eager_tf_executor.EagerTFExecutor())
  ex = thread_delegating_executor.ThreadDelegatingExecutor(cached)
  self.use_executor(ex)
  ex.close()
  self.use_executor(ex)
def make_output():
  """Creates ten values through ten executors that share one `FakeExecutor`.

  Each `ThreadDelegatingExecutor` is asked to create its own index as a value.
  The executors are then dropped and the shared `FakeExecutor`'s `output` is
  returned. `self` is taken from the enclosing test method's scope.

  Returns:
    `test_ex.output` after all thread-delegating executors have been dropped.
  """
  test_ex = FakeExecutor()
  executors = [
      thread_delegating_executor.ThreadDelegatingExecutor(test_ex)
      for _ in range(10)
  ]
  vals = [ex.create_value(idx) for idx, ex in enumerate(executors)]

  async def gather_coro(coros):
    # Gather inside a running loop. Calling `asyncio.gather` or
    # `asyncio.get_event_loop()` without one is deprecated in recent Python.
    return await asyncio.gather(*coros)

  results = asyncio.run(gather_coro(vals))
  self.assertCountEqual(results, list(range(10)))
  del executors
  return test_ex.output
def test_end_to_end(self):
  """Invokes `add_one` twice through the same thread-delegating executor."""

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return tf.add(x, 1)

  executor = thread_delegating_executor.ThreadDelegatingExecutor(
      eager_tf_executor.EagerTFExecutor())
  first = _invoke(executor, add_one, 7)
  self.assertEqual(first, 8)
  # NOTE(review): presumably `_invoke` leaves the executor closed, so this
  # second invocation exercises re-initialization. Confirm against `_invoke`.
  second = _invoke(executor, add_one, 8)
  self.assertEqual(second, 9)
def _wrap_executor_in_threading_stack(ex: executor_base.Executor,
                                      support_sequence_ops: bool = False,
                                      can_resolve_references: bool = True):
  """Wraps `ex` in a `ThreadDelegatingExecutor` plus optional outer layers.

  Args:
    ex: The executor to wrap.
    support_sequence_ops: If true, wraps the threaded executor in a
      `ReferenceResolvingExecutor` and then a `SequenceExecutor`.
    can_resolve_references: If true, adds an outermost
      `ReferenceResolvingExecutor`.

  Returns:
    The outermost executor of the constructed stack.

  Raises:
    ValueError: If `support_sequence_ops` is true but `can_resolve_references`
      is false.
  """
  # Validate before constructing anything, so that invalid arguments do not
  # leave behind a `ThreadDelegatingExecutor` that is never closed.
  if support_sequence_ops and not can_resolve_references:
    raise ValueError(
        'Support for sequence ops requires ability to resolve references.')
  threaded_ex = thread_delegating_executor.ThreadDelegatingExecutor(ex)
  if support_sequence_ops:
    threaded_ex = sequence_executor.SequenceExecutor(
        reference_resolving_executor.ReferenceResolvingExecutor(threaded_ex))
  if can_resolve_references:
    threaded_ex = reference_resolving_executor.ReferenceResolvingExecutor(
        threaded_ex)
  return threaded_ex
def make_output():
  """Creates values 0..9 through ten executors sharing one `FakeExecutor`.

  Asserts that the unwrapped results match `range(10)`, drops the executors,
  and returns the shared `FakeExecutor`'s `output`. `self` is taken from the
  enclosing test method's scope.
  """
  test_ex = FakeExecutor()
  executors = [
      thread_delegating_executor.ThreadDelegatingExecutor(test_ex)
      for _ in range(10)
  ]
  pending = [executor.create_value(i) for i, executor in enumerate(executors)]

  async def gather_coro(coros):
    return await asyncio.gather(*coros)

  thread_values = asyncio.run(gather_coro(pending))
  unwrapped = [tv.internal_representation for tv in thread_values]
  self.assertCountEqual(unwrapped, list(range(10)))
  del executors
  return test_ex.output
def _complete_stack(ex):
  """Stacks reference resolving -> caching -> thread delegation over `ex`."""
  threaded = thread_delegating_executor.ThreadDelegatingExecutor(ex)
  cached = caching_executor.CachingExecutor(threaded)
  return reference_resolving_executor.ReferenceResolvingExecutor(cached)
def _create_bottom_stack():
  """Builds reference resolving -> caching -> thread delegation -> eager TF."""
  eager = eager_tf_executor.EagerTFExecutor()
  threaded = thread_delegating_executor.ThreadDelegatingExecutor(eager)
  cached = caching_executor.CachingExecutor(threaded)
  return reference_resolving_executor.ReferenceResolvingExecutor(cached)
def _threaded_eager_executor() -> executor_base.Executor:
  """Returns a fresh `EagerTFExecutor` behind a `ThreadDelegatingExecutor`."""
  eager_ex = eager_tf_executor.EagerTFExecutor()
  return thread_delegating_executor.ThreadDelegatingExecutor(eager_ex)