예제 #1
0
def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields an executor that talks to a local gRPC executor service.

  Starts a gRPC server on an unused local port. The server serves an
  `ExecutorService` wrapping a `TracingExecutor` over a local executor stack
  (3 clients). The function then connects a `RemoteExecutor` to that server
  over an insecure channel and installs a `LambdaExecutor` on top of it as the
  default executor while the caller's body runs.

  NOTE(review): this is a generator; presumably it is wrapped by
  `contextlib.contextmanager` just above this view — confirm.

  Args:
    rpc_mode: RPC mode string forwarded unchanged to
      `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the installed `LambdaExecutor`) and
    `tracer` (the server-side `TracingExecutor`).
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # The server is running from here on. Every later step, including the
  # caller's body and the default-executor reset, sits inside this try so the
  # server is always stopped, even if setup fails part-way.
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = lambda_executor.LambdaExecutor(remote_exec)
      set_default_executor.set_default_executor(
          executor_factory.ExecutorFactoryImpl(lambda _: executor))
      try:
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
      finally:
        # No-argument call resets the default executor installed above.
        set_default_executor.set_default_executor()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
예제 #2
0
  def test_simple(self):
    """Runs a small computation and checks the TracingExecutor's log."""
    traced = executor_test_utils.TracingExecutor(
        eager_executor.EagerExecutor())

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return tf.add(x, 1)

    async def _make():
      fn_val = await traced.create_value(add_one)
      arg_val = await traced.create_value(10, tf.int32)
      call_val = await traced.create_call(fn_val, arg_val)
      struct_val = await traced.create_tuple(
          anonymous_tuple.AnonymousTuple([('foo', call_val)]))
      picked = await traced.create_selection(struct_val, name='foo')
      return await picked.compute()

    loop = asyncio.get_event_loop()
    result = loop.run_until_complete(_make())
    self.assertEqual(result.numpy(), 11)

    # Each expected entry is: method name, its arguments (executor values
    # appear as numeric ids), then the id of the value it created. The
    # 'compute' entry ends with the computed result instead.
    expected_trace = [
        ('create_value', add_one, 1),
        ('create_value', 10, tf.int32, 2),
        ('create_call', 1, 2, 3),
        ('create_tuple', anonymous_tuple.AnonymousTuple([('foo', 3)]), 4),
        ('create_selection', 4, 'foo', 5),
        ('compute', 5, result),
    ]

    self.assertLen(traced.trace, len(expected_trace))
    # assertCountEqual compares each pair of entries as multisets, so the
    # element order inside an entry is not checked.
    for actual, expected in zip(traced.trace, expected_trace):
      self.assertCountEqual(actual, expected)
예제 #3
0
def _make_executor_and_tracer_for_test(support_lambdas=False):
    """Builds a caching executor stack over a traced eager executor.

    Args:
        support_lambdas: When true, the caching stack is additionally wrapped
            as `LambdaExecutor(CachingExecutor(<caching stack>))`.

    Returns:
        A pair `(executor, tracer)`: the outermost executor of the stack, and
        the `TracingExecutor` sitting directly on the `EagerExecutor`.
    """
    tracer = executor_test_utils.TracingExecutor(
        eager_executor.EagerExecutor())
    cached = caching_executor.CachingExecutor(tracer)
    if not support_lambdas:
        return cached, tracer
    # This deliberately stacks a second CachingExecutor on top of the first.
    outer = lambda_executor.LambdaExecutor(
        caching_executor.CachingExecutor(cached))
    return outer, tracer
예제 #4
0
def _make_executor_and_tracer_for_test(support_lambdas=False):
    """Builds a caching executor stack over a traced eager TF executor.

    Args:
        support_lambdas: When true, the caching stack is additionally wrapped
            as `ReferenceResolvingExecutor(CachingExecutor(<caching stack>))`.

    Returns:
        A pair `(executor, tracer)`: the outermost executor of the stack, and
        the `TracingExecutor` sitting directly on the `EagerTFExecutor`.
    """
    tracer = executor_test_utils.TracingExecutor(
        eager_tf_executor.EagerTFExecutor())
    cached = caching_executor.CachingExecutor(tracer)
    if not support_lambdas:
        return cached, tracer
    # This deliberately stacks a second CachingExecutor on top of the first.
    outer = reference_resolving_executor.ReferenceResolvingExecutor(
        caching_executor.CachingExecutor(cached))
    return outer, tracer