Ejemplo n.º 1
0
  def test_with_block(self):
    """Evaluates a Block-bodied lambda that applies `f` twice to `x`.

    The lambda's single parameter `a` is a named tuple `<f=int32->int32,
    x=int32>`; the block rebinds `f` and `x` to its fields and returns
    `f(f(x))`. Supplying `add_one` and 10 must therefore yield 12.
    """
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
    param = building_blocks.Reference(
        'a',
        computation_types.NamedTupleType([('f', int_to_int),
                                          ('x', tf.int32)]))
    # Block-local bindings: project both named fields out of the parameter.
    bindings = [
        ('f', building_blocks.Selection(param, name='f')),
        ('x', building_blocks.Selection(param, name='x')),
    ]
    # Result expression f(f(x)), using fresh references for each occurrence.
    inner_call = building_blocks.Call(
        building_blocks.Reference('f', int_to_int),
        building_blocks.Reference('x', tf.int32))
    outer_call = building_blocks.Call(
        building_blocks.Reference('f', int_to_int), inner_call)
    lambda_comp = building_blocks.Lambda(
        param.name, param.type_signature,
        building_blocks.Block(bindings, outer_call))

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    comp_val = run(
        executor.create_value(lambda_comp.proto, lambda_comp.type_signature))
    fn_val = run(executor.create_value(add_one))
    int_val = run(executor.create_value(10, tf.int32))
    arg_val = run(
        executor.create_tuple(
            anonymous_tuple.AnonymousTuple([('f', fn_val), ('x', int_val)])))
    call_val = run(executor.create_call(comp_val, arg_val))
    computed = run(call_val.compute())
    self.assertEqual(computed.numpy(), 12)
Ejemplo n.º 2
0
def test_context(rpc_mode='REQUEST_REPLY'):
  """Serves a traced local executor stack over gRPC and yields a client for it.

  The body is a generator with a single `yield`. It is presumably wrapped as
  a context manager where it is defined; that decorator is not visible here.

  Args:
    rpc_mode: Passed through unchanged to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` and `tracer`:
    - `executor` is a LambdaExecutor over a RemoteExecutor connected to the
      in-process server.
    - `tracer` is the TracingExecutor that wraps the server-side
      `local_executor_factory(num_clients=3)` stack.

  On exit, cleanup runs in this order:
  1. `set_default_executor()` is called with no arguments.
  2. The channel is closed, tolerating channels without `close()`.
  3. The server is stopped.
  """
  port = portpicker.pick_unused_port()
  server = grpc.server(logging_pool.pool(max_workers=1))
  server.add_insecure_port('[::]:{}'.format(port))
  tracer = executor_test_utils.TracingExecutor(
      executor_stacks.local_executor_factory(
          num_clients=3).create_executor({}))
  executor_pb2_grpc.add_ExecutorServicer_to_server(
      executor_service.ExecutorService(tracer), server)
  server.start()

  channel = grpc.insecure_channel('localhost:{}'.format(port))
  client = lambda_executor.LambdaExecutor(
      remote_executor.RemoteExecutor(channel, rpc_mode))
  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(lambda _: client))
  try:
    handles = collections.namedtuple('_', 'executor tracer')
    yield handles(client, tracer)
  finally:
    set_default_executor.set_default_executor()
    try:
      channel.close()
    except AttributeError:
      pass  # Some public gRPC channel versions have no close().
    finally:
      # Always stop the server, even if closing the channel failed.
      server.stop(None)
Ejemplo n.º 3
0
  def test_with_federated_map_and_broadcast(self):
    """Broadcasts a server int to three clients and maps `add_one` over them.

    Every placement (None, SERVER and each of the three CLIENTS) is backed by
    the same EagerExecutor instance. Each client should end up holding 11.
    """
    eager = eager_executor.EagerExecutor()
    placement_targets = {
        None: eager,
        placement_literals.SERVER: eager,
        placement_literals.CLIENTS: [eager] * 3,
    }
    executor = lambda_executor.LambdaExecutor(
        federated_executor.FederatedExecutor(placement_targets))
    run = asyncio.get_event_loop().run_until_complete

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @computations.federated_computation(type_factory.at_server(tf.int32))
    def comp(x):
      at_clients = intrinsics.federated_broadcast(x)
      return intrinsics.federated_map(add_one, at_clients)

    comp_val = run(executor.create_value(comp))
    server_val = run(
        executor.create_value(10, type_factory.at_server(tf.int32)))
    call_val = run(executor.create_call(comp_val, server_val))
    per_client = run(call_val.compute())
    self.assertCountEqual([item.numpy() for item in per_client], [11] * 3)
Ejemplo n.º 4
0
  def test_with_no_arg_tf_comp_in_no_arg_fed_comp(self):
    """Calls a no-argument federated computation that returns the constant 10."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    @computations.federated_computation
    def comp():
      return 10

    fn_val = run(executor.create_value(comp))
    call_val = run(executor.create_call(fn_val))
    self.assertEqual(run(call_val.compute()).numpy(), 10)
Ejemplo n.º 5
0
  def test_clear_failure_with_mismatched_types_in_create_call(self):
    """Passing an int32 to a float32 computation raises a clear TypeError."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    @computations.federated_computation(tf.float32)
    def comp(x):
      return x

    fn_val = run(executor.create_value(comp))
    int_arg = run(executor.create_value(10, tf.int32))
    # The mismatch must be reported when the call is created.
    with self.assertRaisesRegex(TypeError, 'incompatible'):
      run(executor.create_call(fn_val, int_arg))
Ejemplo n.º 6
0
def _make_test_executor(
    num_clients=1,
    use_lambda_executor=False,
) -> federated_executor.FederatedExecutor:
    """Builds a FederatedExecutor in which every placement shares one child.

    Args:
      num_clients: Length of the CLIENTS list. Every entry is the same child
        executor object.
      use_lambda_executor: If truthy, the shared EagerExecutor is wrapped in
        a LambdaExecutor before it is installed.

    Returns:
      A `federated_executor.FederatedExecutor` whose SERVER, CLIENTS and None
      entries all refer to the single shared child.
    """
    eager = eager_executor.EagerExecutor()
    child = (lambda_executor.LambdaExecutor(eager)
             if use_lambda_executor else eager)
    return federated_executor.FederatedExecutor({
        placements.SERVER: child,
        placements.CLIENTS: [child] * num_clients,
        None: child,
    })
Ejemplo n.º 7
0
  def test_create_selection_with_tuples(self):
    """Index selection on an unnamed pair returns each element in turn."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    first = run(executor.create_value(10, tf.int32))
    second = run(executor.create_value(20, tf.int32))
    pair = run(
        executor.create_tuple(
            anonymous_tuple.AnonymousTuple([(None, first), (None, second)])))
    # Create both selections first, then compute them in index order.
    picked = [run(executor.create_selection(pair, index=i)) for i in (0, 1)]
    computed = [run(sel.compute()) for sel in picked]
    for got, want in zip(computed, (10, 20)):
      self.assertEqual(got.numpy(), want)
Ejemplo n.º 8
0
  def test_with_tuples(self):
    """A no-arg federated computation calls a two-arg TF computation: 10+20."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    @computations.tf_computation(tf.int32, tf.int32)
    def add_numbers(x, y):
      return x + y

    @computations.federated_computation
    def comp():
      return add_numbers(10, 20)

    fn_val = run(executor.create_value(comp))
    call_val = run(executor.create_call(fn_val))
    total = run(call_val.compute())
    self.assertEqual(total.numpy(), 30)
Ejemplo n.º 9
0
  def test_with_one_arg_tf_comp_in_one_arg_fed_comp(self):
    """A one-arg federated computation applies `add_one` twice: 10 -> 12."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @computations.federated_computation(tf.int32)
    def comp(x):
      once = add_one(x)
      return add_one(once)

    fn_val = run(executor.create_value(comp))
    arg_val = run(executor.create_value(10, tf.int32))
    call_val = run(executor.create_call(fn_val, arg_val))
    outcome = run(call_val.compute())
    self.assertEqual(outcome.numpy(), 12)
Ejemplo n.º 10
0
  def test_with_functional_parameter(self):
    """A federated computation receives a function `f` and returns f(f(x))."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    fn_type = computation_types.FunctionType(tf.int32, tf.int32)

    @computations.federated_computation(fn_type, tf.int32)
    def comp(f, x):
      return f(f(x))

    comp_val = run(executor.create_value(comp))
    fn_val = run(executor.create_value(add_one))
    int_val = run(executor.create_value(10, tf.int32))
    # The two parameters are passed positionally as an unnamed tuple.
    args_val = run(
        executor.create_tuple(
            anonymous_tuple.AnonymousTuple([(None, fn_val), (None, int_val)])))
    call_val = run(executor.create_call(comp_val, args_val))
    outcome = run(call_val.compute())
    self.assertEqual(outcome.numpy(), 12)
Ejemplo n.º 11
0
  def test_with_one_arg_tf_comp_in_two_arg_fed_comp(self):
    """Returns the pairwise sums (x+x, x+y, y+y) for x=10, y=20."""
    executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
    run = asyncio.get_event_loop().run_until_complete

    @computations.tf_computation(tf.int32, tf.int32)
    def add_numbers(x, y):
      return x + y

    @computations.federated_computation(tf.int32, tf.int32)
    def comp(x, y):
      return add_numbers(x, x), add_numbers(x, y), add_numbers(y, y)

    comp_val = run(executor.create_value(comp))
    x_val = run(executor.create_value(10, tf.int32))
    y_val = run(executor.create_value(20, tf.int32))
    args_val = run(
        executor.create_tuple(
            anonymous_tuple.AnonymousTuple([(None, x_val), (None, y_val)])))
    call_val = run(executor.create_call(comp_val, args_val))
    sums = run(call_val.compute())
    # Convert each element to numpy so the printed form is plain integers.
    as_numpy = anonymous_tuple.map_structure(lambda t: t.numpy(), sums)
    self.assertEqual(str(as_numpy), '<20,30,40>')
Ejemplo n.º 12
0
  def test_with_federated_map(self):
    """Maps `add_one` over a server-placed int via federated_map: 10 -> 11."""
    eager = eager_executor.EagerExecutor()
    placement_targets = {
        None: eager,
        placement_literals.SERVER: eager,
    }
    executor = lambda_executor.LambdaExecutor(
        federated_executor.FederatedExecutor(placement_targets))
    run = asyncio.get_event_loop().run_until_complete

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @computations.federated_computation(type_factory.at_server(tf.int32))
    def comp(x):
      return intrinsics.federated_map(add_one, x)

    comp_val = run(executor.create_value(comp))
    server_val = run(
        executor.create_value(10, type_factory.at_server(tf.int32)))
    call_val = run(executor.create_call(comp_val, server_val))
    outcome = run(call_val.compute())
    self.assertEqual(outcome.numpy(), 11)
Ejemplo n.º 13
0
def _create_bottom_stack():
    """Builds a Lambda -> Caching -> Concurrent -> Eager stack (outer to inner).

    Returns:
      The outermost `lambda_executor.LambdaExecutor` of a freshly built stack.
    """
    eager = eager_executor.EagerExecutor()
    concurrent = concurrent_executor.ConcurrentExecutor(eager)
    cached = caching_executor.CachingExecutor(concurrent)
    return lambda_executor.LambdaExecutor(cached)
Ejemplo n.º 14
0
def _make_executor_and_tracer_for_test(support_lambdas=False):
  """Builds a caching executor over a traced EagerExecutor.

  Args:
    support_lambdas: If truthy, the result is additionally wrapped as
      LambdaExecutor(CachingExecutor(...)) on top of the base caching layer.

  Returns:
    An `(executor, tracer)` pair. `tracer` is the TracingExecutor wrapping the
    EagerExecutor at the bottom of the stack.
  """
  tracer = executor_test_utils.TracingExecutor(eager_executor.EagerExecutor())
  cached = caching_executor.CachingExecutor(tracer)
  if not support_lambdas:
    return cached, tracer
  # NOTE(review): this places a second CachingExecutor directly on top of
  # `cached`, giving two caching layers. Callers' trace expectations may
  # depend on it; confirm that it is intentional.
  wrapped = lambda_executor.LambdaExecutor(
      caching_executor.CachingExecutor(cached))
  return wrapped, tracer
Ejemplo n.º 15
0
def _complete_stack(ex):
    """Wraps `ex` as Lambda -> Caching -> Concurrent -> ex (outer to inner).

    Args:
      ex: The executor placed at the bottom of the new stack.

    Returns:
      The outermost `lambda_executor.LambdaExecutor`.
    """
    concurrent = concurrent_executor.ConcurrentExecutor(ex)
    cached = caching_executor.CachingExecutor(concurrent)
    return lambda_executor.LambdaExecutor(cached)
Ejemplo n.º 16
0
 def test_runs_tf(self):
   """Runs the shared `test_runs_tf` check on a LambdaExecutor over Eager."""
   executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
   executor_test_utils.test_runs_tf(self, executor)
Ejemplo n.º 17
0
def _create_middle_stack(children):
    """Builds a Lambda -> Caching -> Composite stack over the given children.

    The CompositeExecutor's own executor is a fresh `_create_bottom_stack()`.

    Args:
      children: Passed unchanged as the CompositeExecutor's second argument.

    Returns:
      The outermost `lambda_executor.LambdaExecutor`.
    """
    composite = composite_executor.CompositeExecutor(_create_bottom_stack(),
                                                     children)
    cached = caching_executor.CachingExecutor(composite)
    return lambda_executor.LambdaExecutor(cached)