def test_simple(self):
    """Runs a value/call/struct/selection/compute pipeline and checks its trace."""
    tracing_ex = executor_test_utils.TracingExecutor(
        eager_tf_executor.EagerTFExecutor())

    @computations.tf_computation(tf.int32)
    def add_one(x):
        return tf.add(x, 1)

    async def _run_pipeline():
        fn_value = await tracing_ex.create_value(add_one)
        arg_value = await tracing_ex.create_value(10, tf.int32)
        call_value = await tracing_ex.create_call(fn_value, arg_value)
        struct_value = await tracing_ex.create_struct(
            structure.Struct([('foo', call_value)]))
        picked = await tracing_ex.create_selection(struct_value, name='foo')
        return await picked.compute()

    event_loop = asyncio.get_event_loop()
    result = event_loop.run_until_complete(_run_pipeline())
    self.assertEqual(result.numpy(), 11)

    # As the expected entries below show, each traced call is recorded as
    # (method name, *arguments, numeric id of the value it produced), with ids
    # handed out 1, 2, 3, ... in creation order; `compute` records its result.
    expected_trace = [
        ('create_value', add_one, 1),
        ('create_value', 10, tf.int32, 2),
        ('create_call', 1, 2, 3),
        ('create_struct', structure.Struct([('foo', 3)]), 4),
        ('create_selection', 4, 'foo', 5),
        ('compute', 5, result),
    ]
    self.assertLen(tracing_ex.trace, len(expected_trace))
    for actual, expected in zip(tracing_ex.trace, expected_trace):
        self.assertEqual(actual, expected)
def test_context(rpc_mode='REQUEST_REPLY'):
    """Yields a remote executor connected to an in-process gRPC executor service.

    Starts a gRPC server on an unused local port. The server hosts an
    `ExecutorService` backed by a `TracingExecutor`, which wraps a local
    executor stack created with `num_clients=3`. A `RemoteExecutor` is
    connected to that server and wrapped in a `ReferenceResolvingExecutor`.

    NOTE(review): this is a generator; it is presumably decorated with
    `contextlib.contextmanager` where it is defined -- confirm.

    Args:
      rpc_mode: RPC mode forwarded unchanged to `remote_executor.RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` (the client-side executor) and
      `tracer` (the server-side `TracingExecutor`).
    """
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_executor = executor_stacks.local_executor_factory(
        num_clients=3).create_executor({})
    tracer = executor_test_utils.TracingExecutor(target_executor)
    service = executor_service.ExecutorService(tracer)
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()
    # Everything after start() is guarded so the server is always stopped --
    # even if client setup fails or one of the close() calls below raises.
    try:
        channel = grpc.insecure_channel('localhost:{}'.format(port))
        remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
        executor = reference_resolving_executor.ReferenceResolvingExecutor(
            remote_exec)
        try:
            yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
        finally:
            executor.close()
            tracer.close()
            try:
                channel.close()
            except AttributeError:
                pass  # Public gRPC channel doesn't support close()
    finally:
        server.stop(None)
def _make_executor_and_tracer_for_test(support_lambdas=False):
    """Builds a test executor stack over a traced `EagerTFExecutor`.

    Args:
      support_lambdas: When true, the `CachingExecutor` is wrapped in a second
        `CachingExecutor` and then in a `ReferenceResolvingExecutor`.

    Returns:
      An `(executor, tracer)` pair, where `tracer` is the `TracingExecutor`
      that directly wraps the `EagerTFExecutor`.
    """
    tracer = executor_test_utils.TracingExecutor(
        eager_tf_executor.EagerTFExecutor())
    cached = caching_executor.CachingExecutor(tracer)
    if not support_lambdas:
        return cached, tracer
    # Layering kept as-is: a second cache sits between the resolver and the
    # first cache.
    resolving = reference_resolving_executor.ReferenceResolvingExecutor(
        caching_executor.CachingExecutor(cached))
    return resolving, tracer
def _tracer_fn(cardinalities):
    """Creates a `TracingExecutor` for `cardinalities` and records it.

    Depends on `target_factory` and `tracers` from the enclosing scope, which
    is not visible here. The new tracer is appended to `tracers` before being
    returned.
    """
    underlying = target_factory.create_executor(cardinalities)
    new_tracer = executor_test_utils.TracingExecutor(underlying)
    tracers.append(new_tracer)
    return new_tracer
def _make_executor_and_tracer_for_test():
    """Returns an `(executor, tracer)` pair for tests.

    `tracer` is a `TracingExecutor` wrapping an `EagerTFExecutor`; the
    returned executor is a `CachingExecutor` stacked directly on `tracer`.
    """
    traced = executor_test_utils.TracingExecutor(
        eager_tf_executor.EagerTFExecutor())
    return caching_executor.CachingExecutor(traced), traced