def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a traced remote-executor pair backed by an in-process gRPC server.

  Starts a local gRPC server hosting an eager executor wrapped in a tracing
  executor, connects a remote executor (wrapped in a lambda executor) to it,
  installs that stack as the default, and yields both the executor and the
  tracer.  All resources are released in the finally block, even if the body
  of the context raises.
  """
  port = portpicker.pick_unused_port()
  worker_pool = logging_pool.pool(max_workers=1)
  grpc_server = grpc.server(worker_pool)
  grpc_server.add_insecure_port('[::]:{}'.format(port))
  tracer = executor_test_utils.TracingExecutor(eager_executor.EagerExecutor())
  executor_pb2_grpc.add_ExecutorServicer_to_server(
      executor_service.ExecutorService(tracer), grpc_server)
  grpc_server.start()
  channel = grpc.insecure_channel('localhost:{}'.format(port))
  remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
  executor = lambda_executor.LambdaExecutor(remote_exec)
  set_default_executor.set_default_executor(executor)
  try:
    yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
  finally:
    # Explicit __del__ call releases the remote executor's resources now,
    # since other objects (the lambda executor) still hold references to it.
    remote_exec.__del__()
    set_default_executor.set_default_executor()
    try:
      channel.close()
    except AttributeError:
      pass  # Public gRPC channel doesn't support close()
    finally:
      grpc_server.stop(None)
def _get_losses_before_and_after_training_single_batch(ex):
  """Trains on a single MNIST batch for 20 steps under executor `ex`.

  Returns the list of 21 batch losses: the loss of the initial model followed
  by the loss after each of the 20 training steps.
  """
  set_default_executor.set_default_executor(ex)
  batch = _mnist_sample_batch
  current_model = _mnist_initial_model
  loss_history = [_mnist_batch_loss(current_model, batch)]
  for _ in range(20):
    current_model = _mnist_batch_train(current_model, batch)
    loss_history.append(_mnist_batch_loss(current_model, batch))
  return loss_history
def test_with_num_clients_larger_than_fanout(self):
  """Sums values from 10 clients on a local stack capped at max_fanout=3."""
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor(max_fanout=3))

  @computations.federated_computation(type_factory.at_clients(tf.int32))
  def foo(x):
    return intrinsics.federated_sum(x)

  self.assertEqual(foo(list(range(1, 11))), 55)
def setUp(self):
  """Installs a two-level composite stack: 2 middle nodes of 3 workers each."""
  super(CompositeExecutorTest, self).setUp()
  middle_stacks = [
      _create_middle_stack([_create_worker_stack() for _ in range(3)])
      for _ in range(2)
  ]
  set_default_executor.set_default_executor(
      _create_middle_stack(middle_stacks))
def test_with_no_args(self):
  """Runs a no-argument TF computation on a default local executor stack."""
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor())

  @computations.tf_computation
  def make_ten():
    return tf.constant(10)

  self.assertEqual(make_ten(), 10)
  set_default_executor.set_default_executor()
def test_basic_functionality(self):
  """Sums the first five elements of an infinite dataset of constant 5s."""

  @computations.tf_computation(computation_types.SequenceType(tf.int32))
  def sum_first_five(ds):
    return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)

  set_default_executor.set_default_executor(eager_executor.EagerExecutor())
  fives = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
  self.assertEqual(sum_first_five(fives), 25)
  # Clearing the default executor should fall back to a reference executor.
  set_default_executor.set_default_executor()
  self.assertIn(
      'ReferenceExecutor',
      str(type(context_stack_impl.context_stack.current).__name__))
def test_context():
  """Yields a remote executor connected to an in-process gRPC server.

  Starts a local gRPC server hosting an eager executor, connects a remote
  executor to it over an insecure channel, and installs it as the default
  executor for the duration of the context.

  Yields:
    The `RemoteExecutor` instance installed as the default.
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = eager_executor.EagerExecutor()
  service = executor_service.ExecutorService(target_executor)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  channel = grpc.insecure_channel('localhost:{}'.format(port))
  executor = remote_executor.RemoteExecutor(channel)
  set_default_executor.set_default_executor(executor)
  try:
    yield executor
  finally:
    # Ensure the default executor is reset and the channel/server are torn
    # down even if the body of the context raises; otherwise a failing test
    # would leak the server and leave a broken default executor installed.
    set_default_executor.set_default_executor()
    try:
      channel.close()
    except AttributeError:
      # The public gRPC channel doesn't support close(); drop the reference
      # instead so it can be garbage-collected.
      del channel
    finally:
      server.stop(None)
def test_with_temperature_sensor_example(self):
  """Averages, across 3 clients, the count of readings over a threshold."""

  @computations.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def count_over(ds, t):
    return ds.reduce(
        np.float32(0),
        lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

  @computations.tf_computation(computation_types.SequenceType(tf.float32))
  def count_total(ds):
    return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

  @computations.federated_computation(
      type_constructors.at_clients(
          computation_types.SequenceType(tf.float32)),
      type_constructors.at_server(tf.float32))
  def comp(temperatures, threshold):
    client_data = intrinsics.federated_zip(
        [temperatures, intrinsics.federated_broadcast(threshold)])
    return intrinsics.federated_mean(
        intrinsics.federated_map(count_over, client_data),
        intrinsics.federated_map(count_total, temperatures))

  num_clients = 3
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor(num_clients))
  to_float = lambda x: tf.cast(x, tf.float32)
  temperatures = [
      tf.data.Dataset.range(n).map(to_float) for n in (10, 20, 30)
  ]
  threshold = 15.0
  result = comp(temperatures, threshold)
  self.assertAlmostEqual(result, 8.333, places=3)
  set_default_executor.set_default_executor()
def test_with_incomplete_temperature_sensor_example(self):
  """Counts, per client, the readings exceeding a broadcast threshold."""

  @computations.federated_computation(
      type_constructors.at_clients(
          computation_types.SequenceType(tf.float32)),
      type_constructors.at_server(tf.float32))
  def comp(temperatures, threshold):

    @computations.tf_computation(
        computation_types.SequenceType(tf.float32), tf.float32)
    def count(ds, t):
      return ds.reduce(
          np.int32(0),
          lambda n, x: n + tf.cast(tf.greater(x, t), tf.int32))

    zipped = intrinsics.federated_zip(
        [temperatures, intrinsics.federated_broadcast(threshold)])
    return intrinsics.federated_map(count, zipped)

  num_clients = 10
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor(num_clients))
  to_float = lambda x: tf.cast(x, tf.float32)
  temperatures = [
      tf.data.Dataset.range(1000).map(to_float) for _ in range(num_clients)
  ]
  threshold = 100.0
  result = comp(temperatures, threshold)
  # Values 101..999 exceed 100.0, so each client reports 899.
  self.assertCountEqual([x.numpy() for x in result], [899] * num_clients)
  set_default_executor.set_default_executor()
def wrapped_fn(self, executor):
  """Installs `executor` as the default, then runs the wrapped test body."""
  # NOTE(review): the default executor is not reset afterwards; presumably
  # the surrounding test harness restores it — confirm against the caller.
  set_default_executor.set_default_executor(executor)
  fn(self)
def tearDown(self):
  """Clears the executor installed by setUp, then delegates to the base class."""
  # NOTE(review): passing None appears equivalent to the no-arg reset used
  # elsewhere in this file — confirm in set_default_executor's contract.
  set_default_executor.set_default_executor(None)
  super(CompositeExecutorTest, self).tearDown()