def test_with_block(self):
  """Calls a lambda whose body is a Block re-binding `f` and `x` from its arg.

  The lambda's parameter `a` has type `<f=(int32 -> int32), x=int32>`; the
  block binds `f = a.f`, `x = a.x` and evaluates `f(f(x))`. With
  `f = add_one` and `x = 10` the result must be 12.
  """
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  loop = asyncio.get_event_loop()
  run = loop.run_until_complete

  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  param = building_blocks.Reference(
      'a',
      computation_types.NamedTupleType([('f', fn_type), ('x', tf.int32)]))
  bindings = [
      ('f', building_blocks.Selection(param, name='f')),
      ('x', building_blocks.Selection(param, name='x')),
  ]
  # f(x), then f(f(x)); each Reference is a separate node, as in the tree
  # this test originally built.
  inner_call = building_blocks.Call(
      building_blocks.Reference('f', fn_type),
      building_blocks.Reference('x', tf.int32))
  outer_call = building_blocks.Call(
      building_blocks.Reference('f', fn_type), inner_call)
  block = building_blocks.Block(bindings, outer_call)
  lam = building_blocks.Lambda(param.name, param.type_signature, block)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  fn_value = run(executor.create_value(lam.proto, lam.type_signature))
  add_one_value = run(executor.create_value(add_one))
  ten_value = run(executor.create_value(10, tf.int32))
  arg_value = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([('f', add_one_value),
                                          ('x', ten_value)])))
  call_value = run(executor.create_call(fn_value, arg_value))
  outcome = run(call_value.compute())
  self.assertEqual(outcome.numpy(), 12)
def test_context(rpc_mode='REQUEST_REPLY'):
  """Serves a traced local executor stack over gRPC and installs a client.

  Starts a gRPC server on an unused local port. The server hosts an
  `ExecutorService` that wraps a 3-client local executor stack in a
  `TracingExecutor`. The function then connects a `RemoteExecutor` to that
  server, wraps it in a `LambdaExecutor`, and installs it as the default
  executor for the duration of the `yield`.

  This is a generator. It is presumably wrapped with
  `contextlib.contextmanager` at its definition site (the decorator is not
  visible here) -- confirm.

  Args:
    rpc_mode: Passed through unchanged to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the client-side executor installed
    as default) and `tracer` (the server-side `TracingExecutor`).
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Each acquired resource gets its own try/finally. A failure while setting
  # up (or tearing down) a later resource must still release the earlier
  # ones; in particular, the started server is always stopped.
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = lambda_executor.LambdaExecutor(remote_exec)
      set_default_executor.set_default_executor(
          executor_factory.ExecutorFactoryImpl(lambda _: executor))
      try:
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
      finally:
        # Called with no argument to undo the override installed above.
        set_default_executor.set_default_executor()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
def test_with_federated_map_and_broadcast(self):
  """Broadcasts a server int32 to 3 clients and maps `add_one` over it.

  Every placement (unplaced, SERVER, and each of the 3 CLIENTS) is backed by
  the same eager executor. Broadcasting 10 and adding one must yield 11 on
  every client.
  """
  eager = eager_executor.EagerExecutor()
  federated = federated_executor.FederatedExecutor({
      None: eager,
      placement_literals.SERVER: eager,
      placement_literals.CLIENTS: [eager] * 3,
  })
  executor = lambda_executor.LambdaExecutor(federated)
  run = asyncio.get_event_loop().run_until_complete

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(type_factory.at_server(tf.int32))
  def comp(x):
    broadcast = intrinsics.federated_broadcast(x)
    return intrinsics.federated_map(add_one, broadcast)

  fn_value = run(executor.create_value(comp))
  arg_value = run(executor.create_value(10, type_factory.at_server(tf.int32)))
  call_value = run(executor.create_call(fn_value, arg_value))
  per_client = run(call_value.compute())
  self.assertCountEqual([member.numpy() for member in per_client], [11] * 3)
def test_with_no_arg_tf_comp_in_no_arg_fed_comp(self):
  """A no-argument federated computation returning 10 evaluates to 10."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  @computations.federated_computation
  def comp():
    return 10

  fn_value = run(executor.create_value(comp))
  call_value = run(executor.create_call(fn_value))
  outcome = run(call_value.compute())
  self.assertEqual(outcome.numpy(), 10)
def test_clear_failure_with_mismatched_types_in_create_call(self):
  """Calling a float32 -> float32 function with an int32 arg raises TypeError.

  The error message must match the regex 'incompatible'.
  """
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  @computations.federated_computation(tf.float32)
  def comp(x):
    return x

  fn_value = run(executor.create_value(comp))
  int_value = run(executor.create_value(10, tf.int32))
  with self.assertRaisesRegex(TypeError, 'incompatible'):
    run(executor.create_call(fn_value, int_value))
def _make_test_executor(
    num_clients=1,
    use_lambda_executor=False,
) -> federated_executor.FederatedExecutor:
  """Builds a FederatedExecutor whose placements all share one leaf executor.

  Args:
    num_clients: Length of the executor list assigned to `placements.CLIENTS`.
    use_lambda_executor: If true, the shared `EagerExecutor` is wrapped in a
      `LambdaExecutor` before being assigned to the placements.

  Returns:
    A `federated_executor.FederatedExecutor` mapping `placements.SERVER`,
    each client slot, and the unplaced key `None` to the same executor
    instance.
  """
  eager = eager_executor.EagerExecutor()
  leaf = (
      lambda_executor.LambdaExecutor(eager) if use_lambda_executor else eager)
  return federated_executor.FederatedExecutor({
      placements.SERVER: leaf,
      placements.CLIENTS: [leaf] * num_clients,
      None: leaf,
  })
def test_create_selection_with_tuples(self):
  """Selecting by index from an unnamed 2-tuple returns each element."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  first = run(executor.create_value(10, tf.int32))
  second = run(executor.create_value(20, tf.int32))
  pair = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([(None, first), (None, second)])))
  # Create both selections first, then compute them in index order.
  selections = [
      run(executor.create_selection(pair, index=i)) for i in range(2)
  ]
  computed = [run(selection.compute()) for selection in selections]
  for value, expected in zip(computed, (10, 20)):
    self.assertEqual(value.numpy(), expected)
def test_with_tuples(self):
  """A no-arg federated computation calls a two-arg TF comp: 10 + 20 == 30."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  @computations.tf_computation(tf.int32, tf.int32)
  def add_numbers(x, y):
    return x + y

  @computations.federated_computation
  def comp():
    return add_numbers(10, 20)

  fn_value = run(executor.create_value(comp))
  call_value = run(executor.create_call(fn_value))
  total = run(call_value.compute())
  self.assertEqual(total.numpy(), 30)
def test_with_one_arg_tf_comp_in_one_arg_fed_comp(self):
  """Applying `add_one` twice inside a federated computation maps 10 to 12."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(tf.int32)
  def comp(x):
    once = add_one(x)
    return add_one(once)

  fn_value = run(executor.create_value(comp))
  arg_value = run(executor.create_value(10, tf.int32))
  call_value = run(executor.create_call(fn_value, arg_value))
  outcome = run(call_value.compute())
  self.assertEqual(outcome.numpy(), 12)
def test_with_functional_parameter(self):
  """A federated computation taking a function `f` evaluates `f(f(x))`.

  The argument is an unnamed 2-tuple `(add_one, 10)`, so the result is 12.
  """
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(
      computation_types.FunctionType(tf.int32, tf.int32), tf.int32)
  def comp(f, x):
    return f(f(x))

  fn_value = run(executor.create_value(comp))
  add_one_value = run(executor.create_value(add_one))
  ten_value = run(executor.create_value(10, tf.int32))
  arg_value = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([(None, add_one_value),
                                          (None, ten_value)])))
  call_value = run(executor.create_call(fn_value, arg_value))
  outcome = run(call_value.compute())
  self.assertEqual(outcome.numpy(), 12)
def test_with_one_arg_tf_comp_in_two_arg_fed_comp(self):
  """A two-arg federated comp returns three sums of its args.

  NOTE(review): despite the test name, the TF computation used here
  (`add_numbers`) takes two arguments. The name is kept unchanged.

  With x=10 and y=20 the result is `<x+x, x+y, y+y>` == `<20,30,40>`.
  """
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete

  @computations.tf_computation(tf.int32, tf.int32)
  def add_numbers(x, y):
    return x + y

  @computations.federated_computation(tf.int32, tf.int32)
  def comp(x, y):
    return add_numbers(x, x), add_numbers(x, y), add_numbers(y, y)

  fn_value = run(executor.create_value(comp))
  x_value = run(executor.create_value(10, tf.int32))
  y_value = run(executor.create_value(20, tf.int32))
  arg_value = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([(None, x_value), (None, y_value)])))
  call_value = run(executor.create_call(fn_value, arg_value))
  structure = run(call_value.compute())
  unwrapped = anonymous_tuple.map_structure(lambda t: t.numpy(), structure)
  self.assertEqual(str(unwrapped), '<20,30,40>')
def test_with_federated_map(self):
  """Mapping `add_one` over a server-placed int32 10 yields 11.

  Only the unplaced (`None`) and SERVER placements are configured; both
  share one eager executor.
  """
  eager = eager_executor.EagerExecutor()
  federated = federated_executor.FederatedExecutor({
      None: eager,
      placement_literals.SERVER: eager,
  })
  executor = lambda_executor.LambdaExecutor(federated)
  run = asyncio.get_event_loop().run_until_complete

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(type_factory.at_server(tf.int32))
  def comp(x):
    return intrinsics.federated_map(add_one, x)

  fn_value = run(executor.create_value(comp))
  arg_value = run(executor.create_value(10, type_factory.at_server(tf.int32)))
  call_value = run(executor.create_call(fn_value, arg_value))
  outcome = run(call_value.compute())
  self.assertEqual(outcome.numpy(), 11)
def _create_bottom_stack():
  """Builds Lambda -> Caching -> Concurrent -> Eager, outermost first."""
  eager = eager_executor.EagerExecutor()
  concurrent = concurrent_executor.ConcurrentExecutor(eager)
  cached = caching_executor.CachingExecutor(concurrent)
  return lambda_executor.LambdaExecutor(cached)
def _make_executor_and_tracer_for_test(support_lambdas=False):
  """Returns `(executor, tracer)`, with a TracingExecutor over an eager leaf.

  Args:
    support_lambdas: If false, the executor is `Caching(tracer)`. If true, it
      is `Lambda(Caching(Caching(tracer)))`.

  Returns:
    A pair of the outermost executor and the `TracingExecutor` it wraps.
  """
  tracer = executor_test_utils.TracingExecutor(eager_executor.EagerExecutor())
  cached = caching_executor.CachingExecutor(tracer)
  if not support_lambdas:
    return cached, tracer
  # NOTE(review): this stacks a second CachingExecutor on top of `cached`;
  # confirm that the double caching layer is intentional.
  with_lambdas = lambda_executor.LambdaExecutor(
      caching_executor.CachingExecutor(cached))
  return with_lambdas, tracer
def _complete_stack(ex):
  """Wraps `ex` as Lambda -> Caching -> Concurrent -> `ex`, outermost first."""
  concurrent = concurrent_executor.ConcurrentExecutor(ex)
  cached = caching_executor.CachingExecutor(concurrent)
  return lambda_executor.LambdaExecutor(cached)
def test_runs_tf(self):
  """Runs the shared `test_runs_tf` check on a LambdaExecutor over eager."""
  stack = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  executor_test_utils.test_runs_tf(self, stack)
def _create_middle_stack(children):
  """Builds Lambda -> Caching -> Composite(bottom stack, `children`).

  Args:
    children: Passed as the second argument to
      `composite_executor.CompositeExecutor`. The first argument is a fresh
      stack from `_create_bottom_stack()`.
  """
  composite = composite_executor.CompositeExecutor(_create_bottom_stack(),
                                                   children)
  cached = caching_executor.CachingExecutor(composite)
  return lambda_executor.LambdaExecutor(cached)