def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a remote-executor test fixture backed by an in-process gRPC server.

  Builds a local executor stack wrapped in a `TracingExecutor`, serves it over
  an insecure gRPC channel on a freshly picked port, installs a
  `LambdaExecutor`-wrapped `RemoteExecutor` as the default executor, and yields
  a namedtuple `(executor, tracer)` for the duration of the test. On exit the
  default executor is reset, the channel closed, and the server stopped.

  Args:
    rpc_mode: RPC mode string passed to `RemoteExecutor`; defaults to
      'REQUEST_REPLY'.

  Yields:
    A namedtuple with fields `executor` (the default-installed executor) and
    `tracer` (the `TracingExecutor` wrapping the server-side target).
  """
  # NOTE(review): this is a generator; presumably wrapped by
  # `contextlib.contextmanager` at the use site — decorator not visible here.
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  # Server-side target: a local stack for 3 clients, wrapped so the test can
  # inspect the trace of calls that reach it.
  target_executor = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Client side: connect over an insecure channel to the server just started.
  channel = grpc.insecure_channel('localhost:{}'.format(port))
  remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
  executor = lambda_executor.LambdaExecutor(remote_exec)
  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(lambda _: executor))
  try:
    yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
  finally:
    # Teardown order matters: reset the default executor first, then release
    # the channel and finally the server.
    set_default_executor.set_default_executor()
    try:
      channel.close()
    except AttributeError:
      pass  # Public gRPC channel doesn't support close()
    finally:
      server.stop(None)
def _get_losses_before_and_after_training_single_batch(ex):
  """Trains the MNIST model on a single batch and records the loss trajectory.

  Installs `ex` as the default executor, then evaluates the batch loss of the
  initial model and after each of 20 training steps on the same sample batch.

  Args:
    ex: The executor (or executor factory) to install as the default.

  Returns:
    A list of 21 loss values: the initial loss followed by the loss after each
    training step.
  """
  set_default_executor.set_default_executor(ex)
  current_model = _mnist_initial_model
  loss_history = [_mnist_batch_loss(current_model, _mnist_sample_batch)]
  num_steps = 20
  for _ in range(num_steps):
    current_model = _mnist_batch_train(current_model, _mnist_sample_batch)
    loss_history.append(_mnist_batch_loss(current_model, _mnist_sample_batch))
  return loss_history
def setUp(self):
  """Installs a two-level composite executor stack as the default executor.

  The stack is one middle stack over two middle stacks, each of which spans
  three worker stacks.
  """
  # Modernized from `super(CompositeExecutorTest, self).setUp()` to the
  # zero-argument Python 3 form used by the sibling setUp/tearDown methods.
  super().setUp()
  set_default_executor.set_default_executor(
      _create_middle_stack([
          _create_middle_stack([_create_worker_stack() for _ in range(3)]),
          _create_middle_stack([_create_worker_stack() for _ in range(3)]),
      ]))
def setUp(self):
  """Builds the composite executor stack and installs it as the default."""
  super().setUp()
  # 2 clients per worker stack * 3 worker stacks * 2 middle stacks
  self._num_clients = 12
  middle_stacks = [
      _create_middle_stack([_create_worker_stack() for _ in range(3)])
      for _ in range(2)
  ]
  set_default_executor.set_default_executor(
      _create_middle_stack(middle_stacks))
def test_with_num_clients_larger_than_fanout(self):
  """Federated sum over 10 clients succeeds when max_fanout is only 3."""
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor(max_fanout=3))

  @computations.federated_computation(type_factory.at_clients(tf.int32))
  def foo(x):
    return intrinsics.federated_sum(x)

  # sum(1..10) == 55, computed across more clients than the fanout limit.
  self.assertEqual(foo(list(range(1, 11))), 55)
def test_with_no_args(self):
  """A zero-argument tf_computation executes on the local executor."""
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor())

  @computations.tf_computation
  def foo():
    return tf.constant(10)

  self.assertEqual(foo(), 10)
  # Restore the default executor so later tests are unaffected.
  set_default_executor.set_default_executor()
def test_basic_functionality(self):
  """A sequence-reducing computation runs on the eager executor.

  Also verifies that after resetting the default executor, the current
  context is an `ExecutionContext`.
  """

  @computations.tf_computation(computation_types.SequenceType(tf.int32))
  def comp(ds):
    # Sum the first five elements of the incoming sequence.
    return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)

  set_default_executor.set_default_executor(eager_executor.EagerExecutor())
  # An infinite dataset of constant 5s; take(5) above bounds the reduction.
  ds = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
  v = comp(ds)
  self.assertEqual(v, 25)
  set_default_executor.set_default_executor()
  # Fixed: `type(...).__name__` is already a str, so the redundant `str()`
  # wrapper was removed.
  self.assertIn('ExecutionContext',
                type(context_stack_impl.context_stack.current).__name__)
def setUp(self):
  """Installs a composite-stack executor factory as the default."""
  super().setUp()
  # 2 clients per worker stack * 3 worker stacks * 2 middle stacks
  self._num_clients = 12

  def _stack_fn(x):
    del x  # Unused
    middle_stacks = [
        _create_middle_stack([_create_worker_stack() for _ in range(3)])
        for _ in range(2)
    ]
    return _create_middle_stack(middle_stacks)

  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(_stack_fn))
def test_end_to_end(self):
  """A tf_computation runs twice in a row through a ConcurrentExecutor."""

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return tf.add(x, 1)

  executor = concurrent_executor.ConcurrentExecutor(
      eager_executor.EagerExecutor())
  set_default_executor.set_default_executor(executor)
  self.assertEqual(add_one(7), 8)
  # After this invocation, the ConcurrentExecutor has been closed, and needs
  # to be re-initialized.
  self.assertEqual(add_one(8), 9)
  set_default_executor.set_default_executor()
def test_executor_factory_raises_with_wrong_cardinalities(self):
  """A factory built for 5 clients rejects mismatched cardinalities."""
  ex_factory = executor_stacks.local_executor_factory(num_clients=5)
  # The `None` placement key is part of the invalid input under test.
  bad_cardinalities = {
      placement_literals.SERVER: 1,
      None: 1,
      placement_literals.CLIENTS: 1,
  }
  with self.assertRaisesRegex(ValueError,
                              'construct an executor with 1 clients'):
    ex_factory.create_executor(bad_cardinalities)

  # NOTE(review): the remainder looks unrelated to the cardinality check —
  # possibly leftover from a neighboring test; behavior preserved as-is.
  @computations.tf_computation
  def foo():
    return tf.constant(10)

  self.assertEqual(foo(), 10)
  set_default_executor.set_default_executor()
def test_with_temperature_sensor_example(self):
  """Federated mean of per-client counts of readings above a threshold.

  Runs the same computation twice: first on a local executor built for
  exactly 3 clients, then on one with no explicit client count (cardinality
  inferred from the argument), asserting the same result both times.
  """

  @computations.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def count_over(ds, t):
    # Number of readings in `ds` strictly greater than threshold `t`.
    return ds.reduce(
        np.float32(0), lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

  @computations.tf_computation(computation_types.SequenceType(tf.float32))
  def count_total(ds):
    # Total number of readings in `ds`.
    return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    # Weighted mean across clients of (count over threshold) weighted by
    # (total count); the server-placed threshold is broadcast to clients.
    return intrinsics.federated_mean(
        intrinsics.federated_map(
            count_over,
            intrinsics.federated_zip(
                [temperatures,
                 intrinsics.federated_broadcast(threshold)])),
        intrinsics.federated_map(count_total, temperatures))

  # First pass: executor built with an explicit client count of 3.
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor(3))
  to_float = lambda x: tf.cast(x, tf.float32)
  # Clients have 10, 20, and 30 readings: 0..9, 0..19, 0..29 respectively.
  temperatures = [
      tf.data.Dataset.range(10).map(to_float),
      tf.data.Dataset.range(20).map(to_float),
      tf.data.Dataset.range(30).map(to_float)
  ]
  threshold = 15.0
  result = comp(temperatures, threshold)
  # Counts over 15.0 are 0, 4, 14; weighted mean = (0+4+14)*? -> 8.333
  # per the weights 10, 20, 30 — TODO confirm exact weighting semantics.
  self.assertAlmostEqual(result, 8.333, places=3)
  # Second pass: executor with no explicit client count (inferred).
  set_default_executor.set_default_executor(
      executor_stacks.create_local_executor())
  to_float = lambda x: tf.cast(x, tf.float32)
  temperatures = [
      tf.data.Dataset.range(10).map(to_float),
      tf.data.Dataset.range(20).map(to_float),
      tf.data.Dataset.range(30).map(to_float)
  ]
  threshold = 15.0
  result = comp(temperatures, threshold)
  self.assertAlmostEqual(result, 8.333, places=3)
  set_default_executor.set_default_executor()
def tearDown(self):
  """Clears the default executor installed in setUp, then tears down."""
  set_default_executor.set_default_executor(None)
  super().tearDown()
def tearDown(self):
  """Clears the default executor installed in setUp, then tears down."""
  set_default_executor.set_default_executor(None)
  # Modernized from `super(CompositeExecutorTest, self).tearDown()` to the
  # zero-argument Python 3 form used by the sibling tearDown method.
  super().tearDown()
def wrapped_fn(self, executor):
  """Installs `executor` as the default executor, then runs the wrapped test.

  NOTE(review): `fn` is a free variable from the enclosing decorator's scope,
  which is outside this view — presumably the test function being wrapped;
  confirm against the surrounding decorator definition.
  """
  set_default_executor.set_default_executor(executor)
  fn(self)
state, mean = federated_aggregate_test([1.0, 2.0, 3.0]) self.assertAlmostEqual(mean, 2.0) # (1 + 2 + 3) / (1 + 1 + 1) self.assertDictEqual(state._asdict(), {'call_count': 1}) def test_execute_with_explicit_weights(self): aggregate_fn = computation_utils.StatefulAggregateFn( initialize_fn=agg_initialize_fn, next_fn=agg_next_fn) @computations.federated_computation( computation_types.FederatedType(tf.float32, placements.CLIENTS), computation_types.FederatedType(tf.float32, placements.CLIENTS)) def federated_aggregate_test(args, weights): state = intrinsics.federated_value(aggregate_fn.initialize(), placements.SERVER) return aggregate_fn(state, args, weights) state, mean = federated_aggregate_test([1.0, 2.0, 3.0], [4.0, 1.0, 1.0]) self.assertAlmostEqual(mean, 1.5) # (1*4 + 2*1 + 3*1) / (4 + 1 + 1) self.assertDictEqual(state._asdict(), {'call_count': 1}) if __name__ == '__main__': tf.compat.v1.enable_v2_behavior() # NOTE: num_clients must be explicit here to correctly test the broadcast # behavior. Otherwise TFF will infer there are zero clients, which is an # error. set_default_executor.set_default_executor( executor_stacks.local_executor_factory(num_clients=3)) test.main()
def test_runs_tf(test_obj, executor):
  """Tests `executor` can run a minimal TF computation."""
  py_typecheck.check_type(executor, executor_base.Executor)
  # Wrap the single executor in a factory that always returns it.
  factory = executor_factory.ExecutorFactoryImpl(lambda _: executor)
  set_default_executor.set_default_executor(factory)
  test_obj.assertEqual(_dummy_tf_computation(), 10)
def test_runs_tf(test_obj, executor):
  """Tests `executor` can run a minimal TF computation."""
  set_default_executor.set_default_executor(executor)
  expected_result = 10
  test_obj.assertEqual(_dummy_tf_computation(), expected_result)
_ = next(iter(x[1])) def test_fetch_value_with_empty_structured_dataset_and_tensors(self): def return_dataset(): ds1 = tf.data.Dataset.from_tensor_slices( collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])])) return [tf.constant([0., 0.]), ds1.batch(5).take(0)] executable_return_dataset = computation_impl.ComputationImpl( tensorflow_serialization.serialize_py_fn_as_tf_computation( return_dataset, None, context_stack_impl.context_stack)[0], context_stack_impl.context_stack) x = executable_return_dataset() self.assertAllEqual(x[0], [0., 0.]) self.assertEqual( tf.data.experimental.get_structure(x[1]), collections.OrderedDict([ ('a', tf.TensorSpec(shape=(None, ), dtype=tf.int32)), ('b', tf.TensorSpec(shape=(None, ), dtype=tf.int32)), ])) with self.assertRaises(StopIteration): _ = next(iter(x[1])) if __name__ == '__main__': # Use the local executor. set_default_executor.set_default_executor( executor_stacks.create_local_executor()) test.main()
state, mean = federated_aggregate_test([1.0, 2.0, 3.0]) self.assertAlmostEqual(mean, 2.0) # (1 + 2 + 3) / (1 + 1 + 1) self.assertDictEqual(state._asdict(), {'call_count': 1}) def test_execute_with_explicit_weights(self): aggregate_fn = computation_utils.StatefulAggregateFn( initialize_fn=agg_initialize_fn, next_fn=agg_next_fn) @computations.federated_computation( computation_types.FederatedType(tf.float32, placements.CLIENTS), computation_types.FederatedType(tf.float32, placements.CLIENTS)) def federated_aggregate_test(args, weights): state = intrinsics.federated_value(aggregate_fn.initialize(), placements.SERVER) return aggregate_fn(state, args, weights) state, mean = federated_aggregate_test([1.0, 2.0, 3.0], [4.0, 1.0, 1.0]) self.assertAlmostEqual(mean, 1.5) # (1*4 + 2*1 + 3*1) / (4 + 1 + 1) self.assertDictEqual(state._asdict(), {'call_count': 1}) if __name__ == '__main__': tf.compat.v1.enable_v2_behavior() # NOTE: num_clients must be explicit here to correctly test the broadcast # behavior. Otherwise TFF will infer there are zero clients, which is an # error. set_default_executor.set_default_executor( executor_stacks.create_local_executor(num_clients=3)) absltest.main()
def test_with_no_args(self):
  """Installs a default local executor built with no arguments.

  NOTE(review): unlike the sibling `test_with_no_args` elsewhere in this
  source, this block performs no assertion after installing the executor —
  it may be truncated at the chunk boundary; confirm against the full file.
  """
  set_default_executor.set_default_executor(
      executor_stacks.local_executor_factory())