def test_simple(self):
        """Runs add_one(10) through a TracingExecutor and checks value and trace.

        The test builds a call, wraps its result in a one-field struct, and
        selects the field back out. It then computes the value, asserts it is 11,
        and compares every recorded trace entry against the expected sequence.
        """
        traced_ex = executor_test_utils.TracingExecutor(
            eager_tf_executor.EagerTFExecutor())

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return tf.add(x, 1)

        async def _run():
            # The order of these awaits is what the expected trace pins down.
            fn_ref = await traced_ex.create_value(add_one)
            arg_ref = await traced_ex.create_value(10, tf.int32)
            call_ref = await traced_ex.create_call(fn_ref, arg_ref)
            struct_ref = await traced_ex.create_struct(
                structure.Struct([('foo', call_ref)]))
            selected_ref = await traced_ex.create_selection(
                struct_ref, name='foo')
            return await selected_ref.compute()

        loop = asyncio.get_event_loop()
        result = loop.run_until_complete(_run())
        self.assertEqual(result.numpy(), 11)

        # The small integers appear to be the tracer's identifiers for values
        # returned by earlier calls (1..5, in creation order); the final
        # 'compute' entry carries the computed result itself.
        expected_trace = [
            ('create_value', add_one, 1),
            ('create_value', 10, tf.int32, 2),
            ('create_call', 1, 2, 3),
            ('create_struct', structure.Struct([('foo', 3)]), 4),
            ('create_selection', 4, 'foo', 5),
            ('compute', 5, result),
        ]

        self.assertLen(traced_ex.trace, len(expected_trace))
        for actual_entry, expected_entry in zip(traced_ex.trace,
                                                expected_trace):
            self.assertEqual(actual_entry, expected_entry)
def test_context(rpc_mode='REQUEST_REPLY'):
    """Serves a traced local executor over gRPC and yields a client for it.

    Starts an in-process gRPC server on an unused port. The server hosts an
    `ExecutorService` around a `TracingExecutor`, which wraps a local executor
    stack built with `num_clients=3`. A `RemoteExecutor` connects to the
    server over an insecure localhost channel and is wrapped in a
    `ReferenceResolvingExecutor`.

    This is a generator that yields exactly once. Presumably it is wrapped with
    `contextlib.contextmanager` where it is used (no decorator is visible here
    -- confirm).

    Args:
      rpc_mode: Passed through unchanged to `remote_executor.RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` (the client-side
      `ReferenceResolvingExecutor`) and `tracer` (the server-side
      `TracingExecutor`).
    """
    import contextlib  # Local import keeps this block self-contained.

    port = portpicker.pick_unused_port()
    # NOTE(review): server_pool is never shut down here; confirm whether
    # server.stop() is enough to release its worker thread.
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_executor = executor_stacks.local_executor_factory(
        num_clients=3).create_executor({})
    tracer = executor_test_utils.TracingExecutor(target_executor)
    service = executor_service.ExecutorService(tracer)
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()

    # ExitStack runs the registered callbacks in reverse registration order
    # and keeps running the rest even if one of them raises. As a result, the
    # server is stopped even when setup after server.start() fails, or when an
    # earlier close() raises. Teardown order is executor, tracer, channel,
    # server.
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(server.stop, None)
        channel = grpc.insecure_channel('localhost:{}'.format(port))

        def _close_channel():
            try:
                channel.close()
            except AttributeError:
                pass  # Public gRPC channel doesn't support close()

        cleanup.callback(_close_channel)
        cleanup.callback(tracer.close)
        remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
        executor = reference_resolving_executor.ReferenceResolvingExecutor(
            remote_exec)
        cleanup.callback(executor.close)
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
Пример #3
0
def _make_executor_and_tracer_for_test(support_lambdas=False):
    """Builds a caching executor stack on top of a traced eager TF executor.

    Args:
      support_lambdas: When truthy, the caching executor is wrapped in a
        second `CachingExecutor` and then in a `ReferenceResolvingExecutor`.

    Returns:
      A `(executor, tracer)` pair. `tracer` is the `TracingExecutor` at the
      bottom of the stack, directly around a fresh `EagerTFExecutor`.
    """
    tracer = executor_test_utils.TracingExecutor(
        eager_tf_executor.EagerTFExecutor())
    cached = caching_executor.CachingExecutor(tracer)
    if not support_lambdas:
        return cached, tracer
    # NOTE(review): this stacks a second CachingExecutor directly on the first
    # one; kept as-is, but confirm the double caching is intended.
    resolving = reference_resolving_executor.ReferenceResolvingExecutor(
        caching_executor.CachingExecutor(cached))
    return resolving, tracer
Пример #4
0
 def _tracer_fn(cardinalities):
     """Creates a tracing executor for `cardinalities` and records it.

     Wraps the executor that the enclosing scope's `target_factory` builds for
     `cardinalities` in a `TracingExecutor`. The tracer is then appended to the
     enclosing `tracers` list and returned.
     """
     inner = target_factory.create_executor(cardinalities)
     traced = executor_test_utils.TracingExecutor(inner)
     tracers.append(traced)
     return traced
Пример #5
0
def _make_executor_and_tracer_for_test():
    """Returns a `(CachingExecutor, TracingExecutor)` pair for tests.

    The tracer wraps a fresh `EagerTFExecutor` and is in turn wrapped by the
    returned `CachingExecutor`. Both are returned so tests can inspect the
    tracer.
    """
    # NOTE(review): a function with this name (taking `support_lambdas`) is
    # also defined earlier in this file; this later definition replaces it
    # at import time -- confirm that is intended.
    traced = executor_test_utils.TracingExecutor(
        eager_tf_executor.EagerTFExecutor())
    return caching_executor.CachingExecutor(traced), traced